]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fix interval and timezone parsing in copy buffers
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 23 Jun 2021 11:45:03 +0000 (12:45 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 25 Jun 2021 15:15:34 +0000 (16:15 +0100)
psycopg3_c/psycopg3_c/types/date.pyx
tests/types/test_date.py

index bb55e68435e43b807260341f4b10a969d3ed1b61..6d7361b74d2e90a9100a45877c9d4ea2e3b58494 100644 (file)
@@ -399,7 +399,8 @@ cdef class DateLoader(CLoader):
         memset(vals, 0, sizeof(vals))
 
         cdef const char *ptr
-        ptr = _parse_date_values(data, vals, NVALUES)
+        cdef const char *end = data + length
+        ptr = _parse_date_values(data, end, vals, NVALUES)
         if ptr == NULL:
             s = bytes(data).decode("utf8", "replace")
             raise e.DataError(f"can't parse date {s!r}")
@@ -445,9 +446,10 @@ cdef class TimeLoader(CLoader):
         cdef int vals[NVALUES]
         memset(vals, 0, sizeof(vals))
         cdef const char *ptr
+        cdef const char *end = data + length
 
         # Parse the first 3 groups of digits
-        ptr = _parse_date_values(data, vals, NVALUES)
+        ptr = _parse_date_values(data, end, vals, NVALUES)
         if ptr == NULL:
             s = bytes(data).decode("utf8", "replace")
             raise e.DataError(f"can't parse time {s!r}")
@@ -501,9 +503,10 @@ cdef class TimetzLoader(CLoader):
         cdef int vals[3]
         memset(vals, 0, sizeof(vals))
         cdef const char *ptr
+        cdef const char *end = data + length
 
         # Parse the first 3 groups of digits (time)
-        ptr = _parse_date_values(data, vals, 3)
+        ptr = _parse_date_values(data, end, vals, 3)
         if ptr == NULL:
             s = bytes(data).decode("utf8", "replace")
             raise e.DataError(f"can't parse timetz {s!r}")
@@ -514,7 +517,7 @@ cdef class TimetzLoader(CLoader):
             ptr = _parse_micros(ptr + 1, &us)
 
         # Parse the timezone
-        cdef int offsecs = _parse_timezone_to_seconds(&ptr)
+        cdef int offsecs = _parse_timezone_to_seconds(&ptr, end)
         if ptr == NULL:
             s = bytes(data).decode("utf8", "replace")
             raise e.DataError(f"can't parse timetz {s!r}")
@@ -586,12 +589,13 @@ cdef class TimestampLoader(CLoader):
             raise e.InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
 
     cdef object cload(self, const char *data, size_t length):
-        if data[length - 1] == b'C':  # ends with BC
+        cdef const char *end = data + length
+        if end[-1] == b'C':  # ends with BC
             s = bytes(data).decode("utf8", "replace")
             raise e.DataError(f"BC timestamp not supported, got {s!r}")
 
         if self._order == ORDER_PGDM or self._order == ORDER_PGMD:
-            return self._cload_pg(data, length)
+            return self._cload_pg(data, end)
 
         DEF D1 = 0
         DEF D2 = 1
@@ -604,7 +608,7 @@ cdef class TimestampLoader(CLoader):
         cdef const char *ptr
 
         # Parse the first 6 groups of digits (date and time)
-        ptr = _parse_date_values(data, vals, 6)
+        ptr = _parse_date_values(data, end, vals, 6)
         if ptr == NULL:
             s = bytes(data).decode("utf8", "replace")
             raise e.DataError(f"can't parse timetz {s!r}")
@@ -630,7 +634,7 @@ cdef class TimestampLoader(CLoader):
             s = bytes(data).decode("utf8", "replace")
             raise e.DataError(f"can't parse timestamp {s!r}: {ex}") from None
 
-    cdef object _cload_pg(self, const char *data, size_t length):
+    cdef object _cload_pg(self, const char *data, const char *end):
         DEF HO = 0
         DEF MI = 1
         DEF SE = 2
@@ -650,7 +654,7 @@ cdef class TimestampLoader(CLoader):
             raise e.DataError(f"can't parse timestamp {s!r}")
 
         # Parse the following 3 groups of digits (time)
-        ptr = _parse_date_values(seps[2] + 1, vals, 3)
+        ptr = _parse_date_values(seps[2] + 1, end, vals, 3)
         if ptr == NULL:
             s = bytes(data).decode("utf8", "replace")
             raise e.DataError(f"can't parse timestamp {s!r}")
@@ -661,7 +665,7 @@ cdef class TimestampLoader(CLoader):
             ptr = _parse_micros(ptr + 1, &us)
 
         # Parse the year
-        ptr = _parse_date_values(ptr + 1, vals + 3, 1)
+        ptr = _parse_date_values(ptr + 1, end, vals + 3, 1)
         if ptr == NULL:
             s = bytes(data).decode("utf8", "replace")
             raise e.DataError(f"can't parse timestamp {s!r}")
@@ -750,7 +754,8 @@ cdef class TimestamptzLoader(_BaseTimestamptzLoader):
             self._order = ORDER_DMY
 
     cdef object cload(self, const char *data, size_t length):
-        if data[length - 1] == b'C':  # ends with BC
+        cdef const char *end = data + length
+        if end[-1] == b'C':  # ends with BC
             s = bytes(data).decode("utf8", "replace")
             raise e.DataError(f"BC timestamptz not supported, got {s!r}")
 
@@ -768,7 +773,7 @@ cdef class TimestamptzLoader(_BaseTimestamptzLoader):
 
         # Parse the first 6 groups of digits (date and time)
         cdef const char *ptr
-        ptr = _parse_date_values(data, vals, 6)
+        ptr = _parse_date_values(data, end, vals, 6)
         if ptr == NULL:
             s = bytes(data).decode("utf8", "replace")
             raise e.DataError(f"can't parse timestamptz {s!r}")
@@ -788,7 +793,7 @@ cdef class TimestamptzLoader(_BaseTimestamptzLoader):
             m, d, y = vals[D1], vals[D2], vals[D3]
 
         # Parse the timezone
-        cdef int offsecs = _parse_timezone_to_seconds(&ptr)
+        cdef int offsecs = _parse_timezone_to_seconds(&ptr, end)
         if ptr == NULL:
             s = bytes(data).decode("utf8", "replace")
             raise e.DataError(f"can't parse timestamptz {s!r}")
@@ -811,7 +816,7 @@ cdef class TimestamptzLoader(_BaseTimestamptzLoader):
             raise e.DataError(f"can't parse timestamptz {s!r}: {ex}") from None
 
     cdef object _cload_notimpl(self, const char *data, size_t length):
-        s = bytes(data).decode("utf8", "replace")
+        s = bytes(data)[:length].decode("utf8", "replace")
         ds = _get_datestyle(self.connection).decode("ascii")
         raise NotImplementedError(
             f"can't parse timestamptz with DateStyle {ds!r}: {s!r}"
@@ -883,6 +888,7 @@ cdef class IntervalLoader(CLoader):
         cdef int val
         cdef const char *ptr = data
         cdef const char *sep
+        cdef const char *end = ptr + length
 
         # If there are spaces, there is a [+|-]n [days|months|years]
         while True:
@@ -893,11 +899,11 @@ cdef class IntervalLoader(CLoader):
                 sign = 0
 
             sep = strchr(ptr, b' ')
-            if sep == NULL:
+            if sep == NULL or sep > end:
                 break
 
             val = 0
-            ptr = _parse_date_values(ptr, &val, 1)
+            ptr = _parse_date_values(ptr, end, &val, 1)
             if ptr == NULL:
                 s = bytes(data).decode("utf8", "replace")
                 raise e.DataError(f"can't parse interval {s!r}")
@@ -917,7 +923,7 @@ cdef class IntervalLoader(CLoader):
 
             # Skip the date part word.
             ptr = strchr(ptr + 1, b' ')
-            if ptr != NULL:
+            if ptr != NULL and ptr < end:
                 ptr += 1
             else:
                 break
@@ -927,7 +933,7 @@ cdef class IntervalLoader(CLoader):
         cdef int vals[NVALS]
         memset(vals, 0, sizeof(vals))
         if ptr != NULL:
-            ptr = _parse_date_values(ptr, vals, NVALS)
+            ptr = _parse_date_values(ptr, end, vals, NVALS)
             if ptr == NULL:
                 s = bytes(data).decode("utf8", "replace")
                 raise e.DataError(f"can't parse interval {s!r}")
@@ -1000,7 +1006,9 @@ cdef class IntervalBinaryLoader(CLoader):
         return cdt.timedelta_new(days + usdays, ussecs, us)
 
 
-cdef const char *_parse_date_values(const char *ptr, int *vals, int nvals):
+cdef const char *_parse_date_values(
+    const char *ptr, const char *end, int *vals, int nvals
+):
     """
     Parse *nvals* numeric values separated by non-numeric chars.
 
@@ -1009,7 +1017,7 @@ cdef const char *_parse_date_values(const char *ptr, int *vals, int nvals):
     Return the pointer at the separator after the final digit.
     """
     cdef int ival = 0
-    while ptr[0]:
+    while ptr < end:
         if b'0' <= ptr[0] <= b'9':
             vals[ival] = vals[ival] * 10 + (ptr[0] - <char>b'0')
         else:
@@ -1046,7 +1054,7 @@ cdef const char *_parse_micros(const char *start, int *us):
     return ptr
 
 
-cdef int _parse_timezone_to_seconds(const char **bufptr):
+cdef int _parse_timezone_to_seconds(const char **bufptr, const char *end):
     """
     Parse a timezone from a string, return Python timezone object.
 
@@ -1064,7 +1072,7 @@ cdef int _parse_timezone_to_seconds(const char **bufptr):
     cdef int vals[NVALS]
     memset(vals, 0, sizeof(vals))
 
-    ptr = _parse_date_values(ptr + 1, vals, NVALS)
+    ptr = _parse_date_values(ptr + 1, end, vals, NVALS)
     if ptr == NULL:
         return 0
 
index 3a969e4c93248e488fbc6020747ae6e9277a30d4..b58302115817a0e26517017c4a35a51af5e979eb 100644 (file)
@@ -337,6 +337,25 @@ class TestDateTimeTz:
         assert rec[0] is True, type
         assert rec[1] == val
 
+    def test_load_copy(self, conn):
+        cur = conn.cursor(binary=False)
+        with cur.copy(
+            """
+            copy (
+                select
+                    '2000-01-01 01:02:03.123456-10:20'::timestamptz,
+                    '11111111'::int4
+            ) to stdout
+            """
+        ) as copy:
+            copy.set_types(["timestamptz", "int4"])
+            rec = copy.read_row()
+
+        tz = dt.timezone(-dt.timedelta(hours=10, minutes=20))
+        want = dt.datetime(2000, 1, 1, 1, 2, 3, 123456, tzinfo=tz)
+        assert rec[0] == want
+        assert rec[1] == 11111111
+
 
 class TestTime:
     @pytest.mark.parametrize(
@@ -456,6 +475,25 @@ class TestTimeTz:
         assert rec[0] is True, type
         assert rec[1] == val
 
+    def test_load_copy(self, conn):
+        cur = conn.cursor(binary=False)
+        with cur.copy(
+            """
+            copy (
+                select
+                    '01:02:03.123456-10:20'::timetz,
+                    '11111111'::int4
+            ) to stdout
+            """
+        ) as copy:
+            copy.set_types(["timetz", "int4"])
+            rec = copy.read_row()
+
+        tz = dt.timezone(-dt.timedelta(hours=10, minutes=20))
+        want = dt.time(1, 2, 3, 123456, tzinfo=tz)
+        assert rec[0] == want
+        assert rec[1] == 11111111
+
 
 class TestInterval:
     dump_timedelta_samples = [
@@ -574,6 +612,24 @@ class TestInterval:
         ).fetchone()
         assert rec == (dt.date(2020, 12, 31), dt.date(9999, 12, 31))
 
+    def test_load_copy(self, conn):
+        cur = conn.cursor(binary=False)
+        with cur.copy(
+            """
+            copy (
+                select
+                    '1 days +00:00:01.000001'::interval,
+                    'foo bar'::text
+            ) to stdout
+            """
+        ) as copy:
+            copy.set_types(["interval", "text"])
+            rec = copy.read_row()
+
+        want = dt.timedelta(days=1, seconds=1, microseconds=1)
+        assert rec[0] == want
+        assert rec[1] == "foo bar"
+
 
 #
 # Support