From 143dbf33059e5d36272545b828e0866e5b01a5c8 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Wed, 23 Jun 2021 12:45:03 +0100 Subject: [PATCH] Fix interval and timezone parsing in copy buffers --- psycopg3_c/psycopg3_c/types/date.pyx | 52 +++++++++++++++----------- tests/types/test_date.py | 56 ++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 22 deletions(-) diff --git a/psycopg3_c/psycopg3_c/types/date.pyx b/psycopg3_c/psycopg3_c/types/date.pyx index bb55e6843..6d7361b74 100644 --- a/psycopg3_c/psycopg3_c/types/date.pyx +++ b/psycopg3_c/psycopg3_c/types/date.pyx @@ -399,7 +399,8 @@ cdef class DateLoader(CLoader): memset(vals, 0, sizeof(vals)) cdef const char *ptr - ptr = _parse_date_values(data, vals, NVALUES) + cdef const char *end = data + length + ptr = _parse_date_values(data, end, vals, NVALUES) if ptr == NULL: s = bytes(data).decode("utf8", "replace") raise e.DataError(f"can't parse date {s!r}") @@ -445,9 +446,10 @@ cdef class TimeLoader(CLoader): cdef int vals[NVALUES] memset(vals, 0, sizeof(vals)) cdef const char *ptr + cdef const char *end = data + length # Parse the first 3 groups of digits - ptr = _parse_date_values(data, vals, NVALUES) + ptr = _parse_date_values(data, end, vals, NVALUES) if ptr == NULL: s = bytes(data).decode("utf8", "replace") raise e.DataError(f"can't parse time {s!r}") @@ -501,9 +503,10 @@ cdef class TimetzLoader(CLoader): cdef int vals[3] memset(vals, 0, sizeof(vals)) cdef const char *ptr + cdef const char *end = data + length # Parse the first 3 groups of digits (time) - ptr = _parse_date_values(data, vals, 3) + ptr = _parse_date_values(data, end, vals, 3) if ptr == NULL: s = bytes(data).decode("utf8", "replace") raise e.DataError(f"can't parse timetz {s!r}") @@ -514,7 +517,7 @@ cdef class TimetzLoader(CLoader): ptr = _parse_micros(ptr + 1, &us) # Parse the timezone - cdef int offsecs = _parse_timezone_to_seconds(&ptr) + cdef int offsecs = _parse_timezone_to_seconds(&ptr, end) if ptr == NULL: s = bytes(data).decode("utf8", "replace") raise e.DataError(f"can't parse timetz {s!r}") @@ -586,12 +589,13 @@ cdef class TimestampLoader(CLoader): raise e.InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}") cdef object cload(self, const char *data, size_t length): - if data[length - 1] == b'C': # ends with BC + cdef const char *end = data + length + if end[-1] == b'C': # ends with BC s = bytes(data).decode("utf8", "replace") raise e.DataError(f"BC timestamp not supported, got {s!r}") if self._order == ORDER_PGDM or self._order == ORDER_PGMD: - return self._cload_pg(data, length) + return self._cload_pg(data, end) DEF D1 = 0 DEF D2 = 1 @@ -604,7 +608,7 @@ cdef class TimestampLoader(CLoader): cdef const char *ptr # Parse the first 6 groups of digits (date and time) - ptr = _parse_date_values(data, vals, 6) + ptr = _parse_date_values(data, end, vals, 6) if ptr == NULL: s = bytes(data).decode("utf8", "replace") raise e.DataError(f"can't parse timetz {s!r}") @@ -630,7 +634,7 @@ cdef class TimestampLoader(CLoader): s = bytes(data).decode("utf8", "replace") raise e.DataError(f"can't parse timestamp {s!r}: {ex}") from None - cdef object _cload_pg(self, const char *data, size_t length): + cdef object _cload_pg(self, const char *data, const char *end): DEF HO = 0 DEF MI = 1 DEF SE = 2 @@ -650,7 +654,7 @@ cdef class TimestampLoader(CLoader): raise e.DataError(f"can't parse timestamp {s!r}") # Parse the following 3 groups of digits (time) - ptr = _parse_date_values(seps[2] + 1, vals, 3) + ptr = _parse_date_values(seps[2] + 1, end, vals, 3) if ptr == NULL: s = bytes(data).decode("utf8", "replace") raise e.DataError(f"can't parse timestamp {s!r}") @@ -661,7 +665,7 @@ cdef class TimestampLoader(CLoader): ptr = _parse_micros(ptr + 1, &us) # Parse the year - ptr = _parse_date_values(ptr + 1, vals + 3, 1) + ptr = _parse_date_values(ptr + 1, end, vals + 3, 1) if ptr == NULL: s = bytes(data).decode("utf8", "replace") raise e.DataError(f"can't parse timestamp {s!r}") @@ -750,7 +754,8 @@ cdef class TimestamptzLoader(_BaseTimestamptzLoader): self._order = ORDER_DMY cdef object cload(self, const char *data, size_t length): - if data[length - 1] == b'C': # ends with BC + cdef const char *end = data + length + if end[-1] == b'C': # ends with BC s = bytes(data).decode("utf8", "replace") raise e.DataError(f"BC timestamptz not supported, got {s!r}") @@ -768,7 +773,7 @@ cdef class TimestamptzLoader(_BaseTimestamptzLoader): # Parse the first 6 groups of digits (date and time) cdef const char *ptr - ptr = _parse_date_values(data, vals, 6) + ptr = _parse_date_values(data, end, vals, 6) if ptr == NULL: s = bytes(data).decode("utf8", "replace") raise e.DataError(f"can't parse timestamptz {s!r}") @@ -788,7 +793,7 @@ cdef class TimestamptzLoader(_BaseTimestamptzLoader): m, d, y = vals[D1], vals[D2], vals[D3] # Parse the timezone - cdef int offsecs = _parse_timezone_to_seconds(&ptr) + cdef int offsecs = _parse_timezone_to_seconds(&ptr, end) if ptr == NULL: s = bytes(data).decode("utf8", "replace") raise e.DataError(f"can't parse timestamptz {s!r}") @@ -811,7 +816,7 @@ cdef class TimestamptzLoader(_BaseTimestamptzLoader): raise e.DataError(f"can't parse timestamptz {s!r}: {ex}") from None cdef object _cload_notimpl(self, const char *data, size_t length): - s = bytes(data).decode("utf8", "replace") + s = bytes(data)[:length].decode("utf8", "replace") ds = _get_datestyle(self.connection).decode("ascii") raise NotImplementedError( f"can't parse timestamptz with DateStyle {ds!r}: {s!r}" @@ -883,6 +888,7 @@ cdef class IntervalLoader(CLoader): cdef int val cdef const char *ptr = data cdef const char *sep + cdef const char *end = ptr + length # If there are spaces, there is a [+|-]n [days|months|years] while True: @@ -893,11 +899,11 @@ cdef class IntervalLoader(CLoader): sign = 0 sep = strchr(ptr, b' ') - if sep == NULL: + if sep == NULL or sep > end: break val = 0 - ptr = _parse_date_values(ptr, &val, 1) + ptr = _parse_date_values(ptr, end, &val, 1) if ptr == NULL: s = bytes(data).decode("utf8", "replace") raise e.DataError(f"can't parse interval {s!r}") @@ -917,7 +923,7 @@ cdef class IntervalLoader(CLoader): # Skip the date part word. ptr = strchr(ptr + 1, b' ') - if ptr != NULL: + if ptr != NULL and ptr < end: ptr += 1 else: break @@ -927,7 +933,7 @@ cdef class IntervalLoader(CLoader): cdef int vals[NVALS] memset(vals, 0, sizeof(vals)) if ptr != NULL: - ptr = _parse_date_values(ptr, vals, NVALS) + ptr = _parse_date_values(ptr, end, vals, NVALS) if ptr == NULL: s = bytes(data).decode("utf8", "replace") raise e.DataError(f"can't parse interval {s!r}") @@ -1000,7 +1006,9 @@ cdef class IntervalBinaryLoader(CLoader): return cdt.timedelta_new(days + usdays, ussecs, us) -cdef const char *_parse_date_values(const char *ptr, int *vals, int nvals): +cdef const char *_parse_date_values( + const char *ptr, const char *end, int *vals, int nvals +): """ Parse *nvals* numeric values separated by non-numeric chars. @@ -1009,7 +1017,7 @@ cdef const char *_parse_date_values(const char *ptr, int *vals, int nvals): Return the pointer at the separator after the final digit. """ cdef int ival = 0 - while ptr[0]: + while ptr < end: if b'0' <= ptr[0] <= b'9': vals[ival] = vals[ival] * 10 + (ptr[0] - b'0') else: @@ -1046,7 +1054,7 @@ cdef const char *_parse_micros(const char *start, int *us): return ptr -cdef int _parse_timezone_to_seconds(const char **bufptr): +cdef int _parse_timezone_to_seconds(const char **bufptr, const char *end): """ Parse a timezone from a string, return Python timezone object. @@ -1064,7 +1072,7 @@ cdef int _parse_timezone_to_seconds(const char **bufptr): cdef int vals[NVALS] memset(vals, 0, sizeof(vals)) - ptr = _parse_date_values(ptr + 1, vals, NVALS) + ptr = _parse_date_values(ptr + 1, end, vals, NVALS) if ptr == NULL: return 0 diff --git a/tests/types/test_date.py b/tests/types/test_date.py index 3a969e4c9..b58302115 100644 --- a/tests/types/test_date.py +++ b/tests/types/test_date.py @@ -337,6 +337,25 @@ class TestDateTimeTz: assert rec[0] is True, type assert rec[1] == val + def test_load_copy(self, conn): + cur = conn.cursor(binary=False) + with cur.copy( + """ + copy ( + select + '2000-01-01 01:02:03.123456-10:20'::timestamptz, + '11111111'::int4 + ) to stdout + """ + ) as copy: + copy.set_types(["timestamptz", "int4"]) + rec = copy.read_row() + + tz = dt.timezone(-dt.timedelta(hours=10, minutes=20)) + want = dt.datetime(2000, 1, 1, 1, 2, 3, 123456, tzinfo=tz) + assert rec[0] == want + assert rec[1] == 11111111 + class TestTime: @pytest.mark.parametrize( @@ -456,6 +475,25 @@ class TestTimeTz: assert rec[0] is True, type assert rec[1] == val + def test_load_copy(self, conn): + cur = conn.cursor(binary=False) + with cur.copy( + """ + copy ( + select + '01:02:03.123456-10:20'::timetz, + '11111111'::int4 + ) to stdout + """ + ) as copy: + copy.set_types(["timetz", "int4"]) + rec = copy.read_row() + + tz = dt.timezone(-dt.timedelta(hours=10, minutes=20)) + want = dt.time(1, 2, 3, 123456, tzinfo=tz) + assert rec[0] == want + assert rec[1] == 11111111 + class TestInterval: dump_timedelta_samples = [ @@ -574,6 +612,24 @@ class TestInterval: ).fetchone() assert rec == (dt.date(2020, 12, 31), dt.date(9999, 12, 31)) + def test_load_copy(self, conn): + cur = conn.cursor(binary=False) + with cur.copy( + """ + copy ( + select + '1 days +00:00:01.000001'::interval, + 'foo bar'::text + ) to stdout + """ + ) as copy: + copy.set_types(["interval", "text"]) + rec = copy.read_row() + + want = dt.timedelta(days=1, seconds=1, microseconds=1) + assert rec[0] == want + assert rec[1] == "foo bar" + # # Support -- 2.47.3