From: Federico Caselli Date: Fri, 10 Dec 2021 13:18:34 +0000 (+0100) Subject: Improve array of enum handling. X-Git-Tag: rel_2_0_0b1~563^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=94afc4f5fc842160468cf7175552125eebf7a510;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Improve array of enum handling. Fixed handling of array of enum values which require escape characters. Fixes: #7418 Change-Id: I50525846f6029dfea9a8ad1cb913424d168d5f62 --- diff --git a/doc/build/changelog/unreleased_14/7418.rst b/doc/build/changelog/unreleased_14/7418.rst new file mode 100644 index 0000000000..e1e192571d --- /dev/null +++ b/doc/build/changelog/unreleased_14/7418.rst @@ -0,0 +1,5 @@ +.. change:: + :tags: bug, postgresql + :tickets: 7418 + + Fixed handling of array of enum values which require escape characters. diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index a8010c0fad..f3e82c9354 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -364,10 +364,11 @@ class ARRAY(sqltypes.ARRAY): if self._against_native_enum: super_rp = process + pattern = re.compile(r"^{(.*)}$") def handle_raw_string(value): - inner = re.match(r"^{(.*)}$", value).group(1) - return inner.split(",") if inner else [] + inner = pattern.match(value).group(1) + return _split_enum_values(inner) def process(value): if value is None: @@ -382,3 +383,27 @@ class ARRAY(sqltypes.ARRAY): ) return process + + +def _split_enum_values(array_string): + if '"' not in array_string: + # no escape char is present so it can just split on the comma + return array_string.split(",") + + # handles quoted strings from: + # r'abc,"quoted","also\\\\quoted", "quoted, comma", "esc \" quot", qpr' + # returns + # ['abc', 'quoted', 'also\\quoted', 'quoted, comma', 'esc " quot', 'qpr'] + text = array_string.replace(r"\"", "_$ESC_QUOTE$_") + text = text.replace(r"\\", "\\") + result = [] + on_quotes = re.split(r'(")', text) + in_quotes = False + for tok in on_quotes: + if tok == '"': + in_quotes = not in_quotes + elif in_quotes: + result.append(tok.replace("_$ESC_QUOTE$_", '"')) + else: + result.extend(re.findall(r"([^\s,]+),?", tok)) + return result diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index a5797dc2f0..69b06403e4 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -1957,6 +1957,23 @@ class ArrayRoundTripTest: def __ne__(self, other): return not self.__eq__(other) + difficult_enum = [ + "Value", + "With space", + "With,comma", + 'With"quote', + "With\\escape", + """Various!@#$%^*()"'\\][{};:.<>|_+~chars""", + ] + + def make_difficult_enum(cls_, native): + return cls_( + *difficult_enum, name="difficult_enum", native_enum=native + ) + + def difficult_enum_values(x): + return [v for i, v in enumerate(difficult_enum) if i != x - 1] + elements = [ (sqltypes.Integer, lambda x: [1, x, 3, 4, 5]), (sqltypes.Text, str_values), @@ -2044,6 +2061,18 @@ class ArrayRoundTripTest: (sqltypes.Enum(AnEnum, native_enum=True), enum_values), (sqltypes.Enum(AnEnum, native_enum=False), enum_values), (postgresql.ENUM(AnEnum, native_enum=True), enum_values), + ( + make_difficult_enum(sqltypes.Enum, native=True), + difficult_enum_values, + ), + ( + make_difficult_enum(sqltypes.Enum, native=False), + difficult_enum_values, + ), + ( + make_difficult_enum(postgresql.ENUM, native=True), + difficult_enum_values, + ), ] if not exclude_json: