]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Drop code duplication in date/time parsing
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 6 Jun 2021 23:25:20 +0000 (00:25 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 6 Jun 2021 23:38:46 +0000 (00:38 +0100)
psycopg3_c/psycopg3_c/types/date.pyx

index ca742c5de97f707c4e885b4e3bca4e06137c0cf4..fabb7c4f2465042401fc6e9955a73e4ec5dfa6da 100644 (file)
@@ -4,7 +4,7 @@ Cython adapters for date/time types.
 
 # Copyright (C) 2021 The Psycopg Team
 
-from libc.string cimport strchr
+from libc.string cimport memset, strchr
 from cpython cimport datetime as cdt
 from cpython.dict cimport PyDict_GetItem
 from cpython.object cimport PyObject, PyObject_CallFunctionObjArgs
@@ -386,19 +386,15 @@ cdef class DateLoader(CLoader):
             s = bytes(data).decode("utf8", "replace")
             raise e.DataError(f"date not supported: {s!r}")
 
-        cdef int vals[3]
-        vals[0] = vals[1] = vals[2] = 0
+        DEF NVALUES = 3
+        cdef int vals[NVALUES]
+        memset(vals, 0, sizeof(vals))
 
-        cdef size_t i
-        cdef int ival = 0
-        for i in range(length):
-            if b'0' <= data[i] <= b'9':
-                vals[ival] = vals[ival] * 10 + (data[i] - <char>b'0')
-            else:
-                ival += 1
-                if ival >= 3:
-                    s = bytes(data).decode("utf8", "replace")
-                    raise e.DataError(f"can't parse date {s!r}")
+        cdef const char *ptr
+        ptr = _parse_date_values(data, vals, NVALUES)
+        if ptr == NULL:
+            s = bytes(data).decode("utf8", "replace")
+            raise e.DataError(f"can't parse date {s!r}")
 
         try:
             if self._order == ORDER_YMD:
@@ -437,45 +433,27 @@ cdef class TimeLoader(CLoader):
 
     cdef object cload(self, const char *data, size_t length):
 
-        DEF HO = 0
-        DEF MI = 1
-        DEF SE = 2
-        DEF MS = 3
-        cdef int vals[4]
-        vals[HO] = vals[MI] = vals[SE] = vals[MS] = 0
+        DEF NVALUES = 3
+        cdef int vals[NVALUES]
+        memset(vals, 0, sizeof(vals))
+        cdef const char *ptr
 
         # Parse the first 3 groups of digits
-        cdef size_t i
-        cdef int ival = HO
-        for i in range(length):
-            if b'0' <= data[i] <= b'9':
-                vals[ival] = vals[ival] * 10 + (data[i] - <char>b'0')
-            else:
-                ival += 1
-                if ival >= MS:
-                    break
-
-        # Parse the 4th group of digits. Count the digits parsed
-        cdef int msdigits = 0
-        for i in range(i + 1, length):
-            if b'0' <= data[i] <= b'9':
-                vals[ival] = vals[ival] * 10 + (data[i] - <char>b'0')
-                msdigits += 1
-            else:
-                s = bytes(data).decode("utf8", "replace")
-                raise e.DataError(f"can't parse time {s!r}")
+        ptr = _parse_date_values(data, vals, NVALUES)
+        if ptr == NULL:
+            s = bytes(data).decode("utf8", "replace")
+            raise e.DataError(f"can't parse time {s!r}")
 
-        # Pad the fraction of second to get millis
-        if vals[MS]:
-            while msdigits < 6:
-                vals[MS] *= 10
-                msdigits += 1
+        # Parse the microseconds
+        cdef int us = 0
+        if ptr[0] == b".":
+            ptr = _parse_micros(ptr + 1, &us)
 
         try:
-            return cdt.time_new(vals[HO], vals[MI], vals[SE], vals[MS], None)
+            return cdt.time_new(vals[0], vals[1], vals[2], us, None)
         except ValueError as ex:
             s = bytes(data).decode("utf8", "replace")
-            raise e.DataError(f"can't parse date {s!r}: {ex}") from None
+            raise e.DataError(f"can't parse time {s!r}: {ex}") from None
 
 
 @cython.final
@@ -512,68 +490,29 @@ cdef class TimetzLoader(CLoader):
 
     cdef object cload(self, const char *data, size_t length):
 
-        DEF HO = 0
-        DEF MI = 1
-        DEF SE = 2
-        DEF MS = 3
-        DEF OH = 4
-        DEF OM = 5
-        DEF OS = 6
-        cdef int vals[7]
-        vals[HO] = vals[MI] = vals[SE] = vals[MS] = 0
-        vals[OH] = vals[OM] = vals[OS] = 0
+        cdef int vals[3]
+        memset(vals, 0, sizeof(vals))
+        cdef const char *ptr
 
-        # Parse the first 3 groups of digits
-        cdef size_t i = 0
-        cdef int ival = HO
-        for i in range(length):
-            if b'0' <= data[i] <= b'9':
-                vals[ival] = vals[ival] * 10 + (data[i] - <char>b'0')
-            else:
-                ival += 1
-                if ival >= MS:
-                    break
-
-        # Are there millis?
-        cdef int msdigits = 0
-        if data[i] == b".":
-            # Parse the 4th group of digits. Count the digits parsed
-            for i in range(i + 1, length):
-                if b'0' <= data[i] <= b'9':
-                    vals[ival] = vals[ival] * 10 + (data[i] - <char>b'0')
-                    msdigits += 1
-                else:
-                    ival += 1
-                    break
-
-            # Pad the fraction of second to get millis
-            while msdigits < 6:
-                vals[MS] *= 10
-                msdigits += 1
+        # Parse the first 3 groups of digits (time)
+        ptr = _parse_date_values(data, vals, 3)
+        if ptr == NULL:
+            s = bytes(data).decode("utf8", "replace")
+            raise e.DataError(f"can't parse timetz {s!r}")
+
+        # Parse the microseconds
+        cdef int us = 0
+        if ptr[0] == b".":
+            ptr = _parse_micros(ptr + 1, &us)
 
         # Parse the timezone
-        cdef char sgn = data[i]
-        ival = OH
-        for i in range(i + 1, length):
-            if b'0' <= data[i] <= b'9':
-                vals[ival] = vals[ival] * 10 + (data[i] - <char>b'0')
-            else:
-                ival += 1
-                if ival > OS:
-                    s = bytes(data).decode("utf8", "replace")
-                    raise e.DataError(f"can't parse timetz {s!r}")
-
-        # Calculate timezone
-        cdef int off = 60 * (60 * vals[OH] + vals[OM])
-        # Python < 3.7 didn't support seconds in the timezones
-        if PY_VERSION_HEX >= 0x03070000:
-            off += vals[OS]
-        if sgn == b"-":
-            off = -off
-
-        tz = timezone_from_seconds(off)
+        tz = _parse_timezone(&ptr)
+        if ptr == NULL:
+            s = bytes(data).decode("utf8", "replace")
+            raise e.DataError(f"can't parse timetz {s!r}")
+
         try:
-            return cdt.time_new(vals[HO], vals[MI], vals[SE], vals[MS], tz)
+            return cdt.time_new(vals[0], vals[1], vals[2], us, tz)
         except ValueError as ex:
             s = bytes(data).decode("utf8", "replace")
             raise e.DataError(f"can't parse timetz {s!r}: {ex}") from None
@@ -636,7 +575,7 @@ cdef class TimestampLoader(CLoader):
     cdef object cload(self, const char *data, size_t length):
         if data[length - 1] == b'C':  # ends with BC
             s = bytes(data).decode("utf8", "replace")
-            raise e.DataError(f"BC timestamps not supported, got {s!r}")
+            raise e.DataError(f"BC timestamp not supported, got {s!r}")
 
         if self._order == ORDER_PGDM or self._order == ORDER_PGMD:
             return self._cload_pg(data, length)
@@ -647,57 +586,33 @@ cdef class TimestampLoader(CLoader):
         DEF HO = 3
         DEF MI = 4
         DEF SE = 5
-        DEF MS = 6
-        cdef int vals[7]
-        vals[D1] = vals[D2] = vals[D3] = 0
-        vals[HO] = vals[MI] = vals[SE] = vals[MS] = 0
-
-        # Parse the first 6 groups of digits
-        cdef size_t i
-        cdef int ival = D1
-        for i in range(length):
-            if b'0' <= data[i] <= b'9':
-                vals[ival] = vals[ival] * 10 + (data[i] - <char>b'0')
-            else:
-                ival += 1
-                if ival >= MS:
-                    break
-
-        # Parse the 7th group of digits. Count the digits parsed
-        cdef int msdigits = 0
-        if data[i] == b'.':
-            for i in range(i + 1, length):
-                if b'0' <= data[i] <= b'9':
-                    vals[ival] = vals[ival] * 10 + (data[i] - <char>b'0')
-                    msdigits += 1
-                else:
-                    s = bytes(data).decode("utf8", "replace")
-                    raise e.DataError(f"can't parse timestamp {s!r}")
-
-            # Pad the fraction of second to get millis
-            if vals[MS]:
-                while msdigits < 6:
-                    vals[MS] *= 10
-                    msdigits += 1
+        cdef int vals[6]
+        memset(vals, 0, sizeof(vals))
+        cdef const char *ptr
+
+        # Parse the first 6 groups of digits (date and time)
+        ptr = _parse_date_values(data, vals, 6)
+        if ptr == NULL:
+            s = bytes(data).decode("utf8", "replace")
+            raise e.DataError(f"can't parse timetz {s!r}")
+
+        # Parse the microseconds
+        cdef int us = 0
+        if ptr[0] == b".":
+            ptr = _parse_micros(ptr + 1, &us)
 
         # Resolve the YMD order
         cdef int y, m, d
         if self._order == ORDER_YMD:
-            y = vals[D1]
-            m = vals[D2]
-            d = vals[D3]
+            y, m, d = vals[D1], vals[D2], vals[D3]
         elif self._order == ORDER_DMY:
-            d = vals[D1]
-            m = vals[D2]
-            y = vals[D3]
+            d, m, y = vals[D1], vals[D2], vals[D3]
         else: # self._order == ORDER_MDY
-            m = vals[D1]
-            d = vals[D2]
-            y = vals[D3]
+            m, d, y = vals[D1], vals[D2], vals[D3]
 
         try:
             return cdt.datetime_new(
-                y, m, d, vals[HO], vals[MI], vals[SE], vals[MS], None)
+                y, m, d, vals[HO], vals[MI], vals[SE], us, None)
         except ValueError as ex:
             s = bytes(data).decode("utf8", "replace")
             raise e.DataError(f"can't parse timestamp {s!r}: {ex}") from None
@@ -706,10 +621,11 @@ cdef class TimestampLoader(CLoader):
         DEF HO = 0
         DEF MI = 1
         DEF SE = 2
-        DEF MS = 3
-        DEF YE = 4
-        cdef int vals[5]
-        vals[HO] = vals[MI] = vals[SE] = vals[MS] = vals[YE] = 0
+        DEF YE = 3
+        DEF NVALS = 4
+        cdef int vals[NVALS]
+        memset(vals, 0, sizeof(vals))
+        cdef const char *ptr
 
         # Find Wed Jun 02 or Wed 02 Jun
         cdef char *seps[3]
@@ -720,39 +636,22 @@ cdef class TimestampLoader(CLoader):
             s = bytes(data).decode("utf8", "replace")
             raise e.DataError(f"can't parse timestamp {s!r}")
 
-        # Parse the following 3 groups of digits
-        cdef size_t i = seps[2] - data
-        cdef int ival = HO
-        for i in range(i + 1, length):
-            if b'0' <= data[i] <= b'9':
-                vals[ival] = vals[ival] * 10 + (data[i] - <char>b'0')
-            else:
-                ival += 1
-                if ival >= MS:
-                    break
-
-        # Parse the ms group of digits. Count the digits parsed
-        cdef int msdigits = 0
-        if data[i] == b'.':
-            for i in range(i + 1, length):
-                if b'0' <= data[i] <= b'9':
-                    vals[ival] = vals[ival] * 10 + (data[i] - <char>b'0')
-                    msdigits += 1
-                else:
-                    break
-
-            # Pad the fraction of second to get millis
-            if vals[MS]:
-                while msdigits < 6:
-                    vals[MS] *= 10
-                    msdigits += 1
+        # Parse the following 3 groups of digits (time)
+        ptr = _parse_date_values(seps[2] + 1, vals, 3)
+        if ptr == NULL:
+            s = bytes(data).decode("utf8", "replace")
+            raise e.DataError(f"can't parse timestamp {s!r}")
+
+        # Parse the microseconds
+        cdef int us = 0
+        if ptr[0] == b".":
+            ptr = _parse_micros(ptr + 1, &us)
 
         # Parse the year
-        for i in range(i + 1, length):
-            if b'0' <= data[i] <= b'9':
-                vals[YE] = vals[YE] * 10 + (data[i] - <char>b'0')
-            else:
-                break
+        ptr = _parse_date_values(ptr + 1, vals + 3, 1)
+        if ptr == NULL:
+            s = bytes(data).decode("utf8", "replace")
+            raise e.DataError(f"can't parse timestamp {s!r}")
 
         # Resolve the MD order
         cdef int m, d
@@ -769,7 +668,7 @@ cdef class TimestampLoader(CLoader):
 
         try:
             return cdt.datetime_new(
-                vals[YE], m, d, vals[HO], vals[MI], vals[SE], vals[MS], None)
+                vals[YE], m, d, vals[HO], vals[MI], vals[SE], us, None)
         except ValueError as ex:
             s = bytes(data).decode("utf8", "replace")
             raise e.DataError(f"can't parse timestamp {s!r}: {ex}") from None
@@ -811,6 +710,88 @@ cdef class TimestampBinaryLoader(CLoader):
                 ) from None
 
 
+cdef const char *_parse_date_values(const char *ptr, int *vals, int nvals):
+    """
+    Parse *nvals* numeric values separated by non-numeric chars.
+
+    Write the result in the *vals* array (assumed zeroed) starting from *start*.
+
+    Return the pointer at the separator after the final digit.
+    """
+    cdef int ival = 0
+    while ptr[0]:
+        if b'0' <= ptr[0] <= b'9':
+            vals[ival] = vals[ival] * 10 + (ptr[0] - <char>b'0')
+        else:
+            ival += 1
+            if ival >= nvals:
+                break
+
+        ptr += 1
+
+    return ptr
+
+
+cdef const char *_parse_micros(const char *ptr, int *us):
+    """
+    Parse microseconds from a string.
+
+    Micros are assumed up to 6 digit chars separated by a non-digit.
+
+    Return the pointer at the separator after the final digit.
+    """
+    cdef int ndigits = 0
+    while ptr[0]:
+        if b'0' <= ptr[0] <= b'9':
+            us[0] = us[0] * 10 + (ptr[0] - <char>b'0')
+            ndigits += 1
+        else:
+            break
+
+        ptr += 1
+
+    # Pad the fraction of second to get millis
+    if us[0]:
+        while ndigits < 6:
+            us[0] *= 10
+            ndigits += 1
+
+    return ptr
+
+
+cdef object _parse_timezone(const char **bufptr):
+    """
+    Parse a timezone from a string, return Python timezone object.
+
+    Modify the buffer pointer to point at the first character after the
+    timezone parsed. In case of parse error make it NULL.
+    """
+    cdef const char *ptr = bufptr[0]
+    cdef char sgn = ptr[0]
+
+    # Parse at most three digits
+    DEF OH = 0
+    DEF OM = 1
+    DEF OS = 2
+    DEF NVALS = 3
+    cdef int vals[NVALS]
+    memset(vals, 0, sizeof(vals))
+
+    ptr = _parse_date_values(ptr + 1, vals, NVALS)
+    if ptr == NULL:
+        return None
+
+    # Calculate timezone
+    cdef int off = 60 * (60 * vals[OH] + vals[OM])
+    # Python < 3.7 didn't support seconds in the timezones
+    if PY_VERSION_HEX >= 0x03070000:
+        off += vals[OS]
+    if sgn == b"-":
+        off = -off
+
+    return timezone_from_seconds(off)
+
+
 cdef object timezone_from_seconds(int sec, __cache={}):
     cdef object pysec = sec
     cdef PyObject *ptr = PyDict_GetItem(__cache, pysec)