From e916d00fd7c28424a22d424202643d258acb23d7 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Fri, 10 Dec 2021 14:18:34 +0100 Subject: [PATCH] Improve array of enum handling. Fixed handling of array of enum values which require escape characters. Fixes: #7418 Change-Id: I50525846f6029dfea9a8ad1cb913424d168d5f62 (cherry picked from commit 94afc4f5fc842160468cf7175552125eebf7a510) --- doc/build/changelog/unreleased_14/7418.rst | 5 ++++ lib/sqlalchemy/dialects/postgresql/array.py | 29 +++++++++++++++++++-- test/dialect/postgresql/test_types.py | 29 +++++++++++++++++++++ 3 files changed, 61 insertions(+), 2 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/7418.rst diff --git a/doc/build/changelog/unreleased_14/7418.rst b/doc/build/changelog/unreleased_14/7418.rst new file mode 100644 index 0000000000..e1e192571d --- /dev/null +++ b/doc/build/changelog/unreleased_14/7418.rst @@ -0,0 +1,5 @@ +.. change:: + :tags: bug, postgresql + :tickets: 7418 + + Fixed handling of array of enum values which require escape characters. diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index e57a4fc9ac..4f296e8ef1 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -367,10 +367,11 @@ class ARRAY(sqltypes.ARRAY): if self._against_native_enum: super_rp = process + pattern = re.compile(r"^{(.*)}$") def handle_raw_string(value): - inner = re.match(r"^{(.*)}$", value).group(1) - return inner.split(",") if inner else [] + inner = pattern.match(value).group(1) + return _split_enum_values(inner) def process(value): if value is None: @@ -385,3 +386,27 @@ class ARRAY(sqltypes.ARRAY): ) return process + + +def _split_enum_values(array_string): + if '"' not in array_string: + # no escape char is present so it can just split on the comma + return array_string.split(",") + + # handles quoted strings from: + # r'abc,"quoted","also\\\\quoted", "quoted, comma", "esc \" quot", qpr' + # returns + # ['abc', 'quoted', 'also\\quoted', 'quoted, comma', 'esc " quot', 'qpr'] + text = array_string.replace(r"\"", "_$ESC_QUOTE$_") + text = text.replace(r"\\", "\\") + result = [] + on_quotes = re.split(r'(")', text) + in_quotes = False + for tok in on_quotes: + if tok == '"': + in_quotes = not in_quotes + elif in_quotes: + result.append(tok.replace("_$ESC_QUOTE$_", '"')) + else: + result.extend(re.findall(r"([^\s,]+),?", tok)) + return result diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 4f26a6ef66..bbd5cadda1 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -1954,6 +1954,23 @@ class ArrayRoundTripTest(object): def __ne__(self, other): return not self.__eq__(other) + difficult_enum = [ + "Value", + "With space", + "With,comma", + 'With"quote', + "With\\escape", + """Various!@#$%^*()"'\\][{};:.<>|_+~chars""", + ] + + def make_difficult_enum(cls_, native): + return cls_( + *difficult_enum, name="difficult_enum", native_enum=native + ) + + def difficult_enum_values(x): + return [v for i, v in enumerate(difficult_enum) if i != x - 1] + elements = [ (sqltypes.Integer, lambda x: [1, x, 3, 4, 5]), (sqltypes.Text, str_values), @@ -2041,6 +2058,18 @@ class ArrayRoundTripTest(object): (sqltypes.Enum(AnEnum, native_enum=True), enum_values), (sqltypes.Enum(AnEnum, native_enum=False), enum_values), (postgresql.ENUM(AnEnum, native_enum=True), enum_values), + ( + make_difficult_enum(sqltypes.Enum, native=True), + difficult_enum_values, + ), + ( + make_difficult_enum(sqltypes.Enum, native=False), + difficult_enum_values, + ), + ( + make_difficult_enum(postgresql.ENUM, native=True), + difficult_enum_values, + ), ] if not exclude_json: -- 2.47.2