git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improve array of enum handling.
author Federico Caselli <cfederico87@gmail.com>
Fri, 10 Dec 2021 13:18:34 +0000 (14:18 +0100)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Jan 2022 19:33:46 +0000 (14:33 -0500)
Fixed handling of array of enum values which require escape characters.

Fixes: #7418
Change-Id: I50525846f6029dfea9a8ad1cb913424d168d5f62

doc/build/changelog/unreleased_14/7418.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/array.py
test/dialect/postgresql/test_types.py

diff --git a/doc/build/changelog/unreleased_14/7418.rst b/doc/build/changelog/unreleased_14/7418.rst
new file mode 100644 (file)
index 0000000..e1e1925
--- /dev/null
@@ -0,0 +1,5 @@
+.. change::
+    :tags: bug, postgresql
+    :tickets: 7418
+
+    Fixed handling of array of enum values which require escape characters.
diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py
index a8010c0fadfc7931418c06f071743ccb6f075b34..f3e82c93540ccfdc66f7aa849816a75a19f04c39 100644
@@ -364,10 +364,11 @@ class ARRAY(sqltypes.ARRAY):
 
         if self._against_native_enum:
             super_rp = process
+            pattern = re.compile(r"^{(.*)}$")
 
             def handle_raw_string(value):
-                inner = re.match(r"^{(.*)}$", value).group(1)
-                return inner.split(",") if inner else []
+                inner = pattern.match(value).group(1)
+                return _split_enum_values(inner)
 
             def process(value):
                 if value is None:
@@ -382,3 +383,42 @@ class ARRAY(sqltypes.ARRAY):
                 )
 
         return process
+
+
+# Matches one element of the text between the braces of a PostgreSQL array
+# literal: a double-quoted element (group 1, may hold backslash escapes) or
+# a bare run of characters without whitespace, comma or quote (group 2).
+_ENUM_ELEMENT = re.compile(r'"((?:[^"\\]|\\.)*)"|([^\s,"]+)', re.DOTALL)
+# A backslash followed by any character stands for that character.
+_ENUM_UNESCAPE = re.compile(r"\\(.)", re.DOTALL)
+
+
+def _split_enum_values(array_string):
+    """Split the inner text of a PostgreSQL enum array into its elements.
+
+    :param array_string: the text between the outer ``{`` and ``}``, e.g.
+     ``abc,"quoted, comma",qpr``.
+    :return: list of element strings with the double quotes and backslash
+     escapes removed; an empty list when ``array_string`` is empty.
+    """
+    if not array_string:
+        # an empty array has no elements; "".split(",") would give [""]
+        return []
+    if '"' not in array_string:
+        # no escape char is present so it can just split on the comma
+        return array_string.split(",")
+
+    # handles quoted strings from:
+    # r'abc,"quoted","also\\\\quoted", "quoted, comma", "esc \" quot", qpr'
+    # returns
+    # ['abc', 'quoted', 'also\\quoted', 'quoted, comma', 'esc " quot', 'qpr']
+    # Escapes are decoded in one left-to-right pass so that an escaped
+    # backslash right before the closing quote is not mistaken for ``\"``.
+    result = []
+    for match in _ENUM_ELEMENT.finditer(array_string):
+        quoted = match.group(1)
+        if quoted is None:
+            result.append(match.group(2))
+        else:
+            result.append(_ENUM_UNESCAPE.sub(r"\1", quoted))
+    return result
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index a5797dc2f03e4e160808fd5eb2df56a90af03de4..69b06403e428c6455fa2fa336e07fca173c2f27c 100644
@@ -1957,6 +1957,23 @@ class ArrayRoundTripTest:
             def __ne__(self, other):
                 return not self.__eq__(other)
 
+        difficult_enum = [
+            "Value",
+            "With space",
+            "With,comma",
+            'With"quote',
+            "With\\escape",
+            """Various!@#$%^*()"'\\][{};:.<>|_+~chars""",
+        ]
+
+        def make_difficult_enum(cls_, native):
+            return cls_(
+                *difficult_enum, name="difficult_enum", native_enum=native
+            )
+
+        def difficult_enum_values(x):
+            return [v for i, v in enumerate(difficult_enum) if i != x - 1]
+
         elements = [
             (sqltypes.Integer, lambda x: [1, x, 3, 4, 5]),
             (sqltypes.Text, str_values),
@@ -2044,6 +2061,18 @@ class ArrayRoundTripTest:
             (sqltypes.Enum(AnEnum, native_enum=True), enum_values),
             (sqltypes.Enum(AnEnum, native_enum=False), enum_values),
             (postgresql.ENUM(AnEnum, native_enum=True), enum_values),
+            (
+                make_difficult_enum(sqltypes.Enum, native=True),
+                difficult_enum_values,
+            ),
+            (
+                make_difficult_enum(sqltypes.Enum, native=False),
+                difficult_enum_values,
+            ),
+            (
+                make_difficult_enum(postgresql.ENUM, native=True),
+                difficult_enum_values,
+            ),
         ]
 
         if not exclude_json: