]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
interpret NULL in PG enum array values
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 5 Sep 2025 13:29:34 +0000 (09:29 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 5 Sep 2025 15:44:17 +0000 (11:44 -0400)
Fixed issue where selecting an enum array column containing NULL values
would fail to parse properly in the PostgreSQL dialect. The
:func:`._split_enum_values` function now correctly handles NULL entries by
converting them to Python ``None`` values.

Fixes: #12847
Change-Id: I39d10bc1be6b458da7e5d3f4b740f8faafd0adc5

doc/build/changelog/unreleased_20/12847.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/array.py
test/dialect/postgresql/test_types.py

diff --git a/doc/build/changelog/unreleased_20/12847.rst b/doc/build/changelog/unreleased_20/12847.rst
new file mode 100644 (file)
index 0000000..bba7849
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, postgresql
+    :tickets: 12847
+
+    Fixed issue where selecting an enum array column containing NULL values
+    would fail to parse properly in the PostgreSQL dialect. The
+    :func:`._split_enum_values` function now correctly handles NULL entries by
+    converting them to Python ``None`` values.
index 62042c6695295af23d7819705a63a1ed57a0165e..45b53f0d0494962c7eac7c748e38f3ee406485fd 100644 (file)
@@ -462,7 +462,7 @@ class ARRAY(sqltypes.ARRAY[_T]):
             super_rp = process
             pattern = re.compile(r"^{(.*)}$")
 
-            def handle_raw_string(value: str) -> list[str]:
+            def handle_raw_string(value: str) -> Sequence[Optional[str]]:
                 inner = pattern.match(value).group(1)  # type: ignore[union-attr]  # noqa: E501
                 return _split_enum_values(inner)
 
@@ -483,10 +483,13 @@ class ARRAY(sqltypes.ARRAY[_T]):
         return process
 
 
-def _split_enum_values(array_string: str) -> list[str]:
+def _split_enum_values(array_string: str) -> Sequence[Optional[str]]:
     if '"' not in array_string:
         # no escape char is present so it can just split on the comma
-        return array_string.split(",") if array_string else []
+        return [
+            r if r != "NULL" else None
+            for r in (array_string.split(",") if array_string else [])
+        ]
 
     # handles quoted strings from:
     # r'abc,"quoted","also\\\\quoted", "quoted, comma", "esc \" quot", qpr'
@@ -503,5 +506,11 @@ def _split_enum_values(array_string: str) -> list[str]:
         elif in_quotes:
             result.append(tok.replace("_$ESC_QUOTE$_", '"'))
         else:
-            result.extend(re.findall(r"([^\s,]+),?", tok))
+            # interpret NULL (without quotes!) as None
+            result.extend(
+                [
+                    r if r != "NULL" else None
+                    for r in re.findall(r"([^\s,]+),?", tok)
+                ]
+            )
     return result
index 42b537e8daf7e381533709e5154fe94307310395..c45b4bc9b2bd50ef530c150513c3185ce62f9c07 100644 (file)
@@ -1,6 +1,7 @@
 import datetime
 import decimal
 from enum import Enum as _PY_Enum
+import functools
 from ipaddress import IPv4Address
 from ipaddress import IPv4Network
 from ipaddress import IPv6Address
@@ -2642,7 +2643,11 @@ class ArrayRoundTripTest:
         t.drop(connection)
         eq_(inspect(connection).get_enums(), [])
 
-    def _type_combinations(exclude_json=False, exclude_empty_lists=False):
+    def _type_combinations(
+        exclude_json=False,
+        exclude_empty_lists=False,
+        exclude_arrays_with_none=False,
+    ):
         def str_values(x):
             return ["one", "two: %s" % x, "three", "four", "five"]
 
@@ -2694,22 +2699,25 @@ class ArrayRoundTripTest:
             def __ne__(self, other):
                 return not self.__eq__(other)
 
+        simple_enum = ["one", "two", "three", "four", "five", "six"]
         difficult_enum = [
             "Value",
             "With space",
             "With,comma",
+            "NULL",
             'With"quote',
             "With\\escape",
             """Various!@#$%^*()"'\\][{};:.<>|_+~chars""",
         ]
 
-        def make_difficult_enum(cls_, native):
-            return cls_(
-                *difficult_enum, name="difficult_enum", native_enum=native
-            )
+        def make_enum(cls_, members, native):
+            return cls_(*members, name="difficult_enum", native_enum=native)
 
-        def difficult_enum_values(x):
-            return [v for i, v in enumerate(difficult_enum) if i != x - 1]
+        def make_enum_values(members, x, *, include_none=False):
+            arr = [v for i, v in enumerate(members) if i != x - 1]
+            if include_none:
+                arr.insert(2, None)
+            return arr
 
         elements = [
             (sqltypes.Integer, lambda x: [1, x, 3, 4, 5]),
@@ -2802,19 +2810,71 @@ class ArrayRoundTripTest:
                 enum_values,
             ),
             (
-                make_difficult_enum(sqltypes.Enum, native=True),
-                difficult_enum_values,
+                make_enum(sqltypes.Enum, difficult_enum, native=True),
+                functools.partial(make_enum_values, difficult_enum),
             ),
             (
-                make_difficult_enum(sqltypes.Enum, native=False),
-                difficult_enum_values,
+                make_enum(sqltypes.Enum, difficult_enum, native=False),
+                functools.partial(make_enum_values, difficult_enum),
             ),
             (
-                make_difficult_enum(postgresql.ENUM, native=True),
-                difficult_enum_values,
+                make_enum(postgresql.ENUM, difficult_enum, native=True),
+                functools.partial(make_enum_values, difficult_enum),
             ),
         ]
 
+        if not exclude_arrays_with_none:
+            elements.extend(
+                [
+                    (
+                        # unquoted ENUM values including NULL in the data
+                        make_enum(sqltypes.Enum, simple_enum, native=True),
+                        functools.partial(
+                            make_enum_values, simple_enum, include_none=True
+                        ),
+                    ),
+                    (
+                        # unquoted ENUM values including NULL in the data
+                        make_enum(sqltypes.Enum, simple_enum, native=False),
+                        functools.partial(
+                            make_enum_values, simple_enum, include_none=True
+                        ),
+                    ),
+                    (
+                        # unquoted ENUM values including NULL in the data
+                        make_enum(postgresql.ENUM, simple_enum, native=True),
+                        functools.partial(
+                            make_enum_values, simple_enum, include_none=True
+                        ),
+                    ),
+                    (
+                        # quoted ENUM values, including both
+                        # quoted "NULL" and real NULL in the data
+                        make_enum(sqltypes.Enum, difficult_enum, native=True),
+                        functools.partial(
+                            make_enum_values, difficult_enum, include_none=True
+                        ),
+                    ),
+                    (
+                        # quoted ENUM values, including both
+                        # quoted "NULL" and real NULL in the data
+                        make_enum(sqltypes.Enum, difficult_enum, native=False),
+                        functools.partial(
+                            make_enum_values, difficult_enum, include_none=True
+                        ),
+                    ),
+                    (
+                        # quoted ENUM values, including both
+                        # quoted "NULL" and real NULL in the data
+                        make_enum(
+                            postgresql.ENUM, difficult_enum, native=True
+                        ),
+                        functools.partial(
+                            make_enum_values, difficult_enum, include_none=True
+                        ),
+                    ),
+                ]
+            )
         if not exclude_empty_lists:
             elements.extend(
                 [
@@ -2854,7 +2914,7 @@ class ArrayRoundTripTest:
                 elements[i] = elem
 
         return testing.combinations_list(
-            elements, argnames="type_,gen", id_="na"
+            elements, argnames="type_,generate_data", id_="na"
         )
 
     @classmethod
@@ -2880,10 +2940,13 @@ class ArrayRoundTripTest:
 
         meta.create_all(connection)
 
-        def go(gen):
+        def go(generate_data):
             connection.execute(
                 table.insert(),
-                [{"id": 1, "bar": gen(1)}, {"id": 2, "bar": gen(2)}],
+                [
+                    {"id": 1, "bar": generate_data(1)},
+                    {"id": 2, "bar": generate_data(2)},
+                ],
             )
             return table
 
@@ -2891,23 +2954,23 @@ class ArrayRoundTripTest:
 
     @_type_combinations()
     def test_type_specific_value_select(
-        self, type_specific_fixture, connection, type_, gen
+        self, type_specific_fixture, connection, type_, generate_data
     ):
-        table = type_specific_fixture(gen)
+        table = type_specific_fixture(generate_data)
 
         rows = connection.execute(
             select(table.c.bar).order_by(table.c.id)
         ).all()
 
-        eq_(rows, [(gen(1),), (gen(2),)])
+        eq_(rows, [(generate_data(1),), (generate_data(2),)])
 
     @_type_combinations()
     def test_type_specific_value_update(
-        self, type_specific_fixture, connection, type_, gen
+        self, type_specific_fixture, connection, type_, generate_data
     ):
-        table = type_specific_fixture(gen)
+        table = type_specific_fixture(generate_data)
 
-        new_gen = gen(3)
+        new_gen = generate_data(3)
         connection.execute(
             table.update().where(table.c.id == 2).values(bar=new_gen)
         )
@@ -2919,11 +2982,11 @@ class ArrayRoundTripTest:
 
     @_type_combinations(exclude_empty_lists=True)
     def test_type_specific_slice_update(
-        self, type_specific_fixture, connection, type_, gen
+        self, type_specific_fixture, connection, type_, generate_data
     ):
-        table = type_specific_fixture(gen)
+        table = type_specific_fixture(generate_data)
 
-        new_gen = gen(3)
+        new_gen = generate_data(3)
 
         if not table.c.bar.type._variant_mapping:
             # this is not likely to occur to users but we need to just
@@ -2939,18 +3002,18 @@ class ArrayRoundTripTest:
             select(table.c.bar).order_by(table.c.id)
         ).all()
 
-        sliced_gen = gen(2)
+        sliced_gen = generate_data(2)
         sliced_gen[0:3] = new_gen[1:4]
 
-        eq_(rows, [(gen(1),), (sliced_gen,)])
+        eq_(rows, [(generate_data(1),), (sliced_gen,)])
 
     @_type_combinations(exclude_json=True, exclude_empty_lists=True)
     def test_type_specific_value_delete(
-        self, type_specific_fixture, connection, type_, gen
+        self, type_specific_fixture, connection, type_, generate_data
     ):
-        table = type_specific_fixture(gen)
+        table = type_specific_fixture(generate_data)
 
-        new_gen = gen(2)
+        new_gen = generate_data(2)
 
         connection.execute(table.delete().where(table.c.bar == new_gen))
 
@@ -2968,22 +3031,27 @@ class PGArrayRoundTripTest(
 ):
     ARRAY = postgresql.ARRAY
 
-    @ArrayRoundTripTest._cls_type_combinations(exclude_json=True)
+    @ArrayRoundTripTest._cls_type_combinations(
+        exclude_json=True, exclude_arrays_with_none=True
+    )
     def test_type_specific_contains(
-        self, type_specific_fixture, connection, type_, gen
+        self, type_specific_fixture, connection, type_, generate_data
     ):
-        table = type_specific_fixture(gen)
+        table = type_specific_fixture(generate_data)
 
         connection.execute(
             table.insert(),
-            [{"id": 1, "bar": gen(1)}, {"id": 2, "bar": gen(2)}],
+            [
+                {"id": 1, "bar": generate_data(1)},
+                {"id": 2, "bar": generate_data(2)},
+            ],
         )
 
         id_, value = connection.execute(
-            select(table).where(table.c.bar.contains(gen(1)))
+            select(table).where(table.c.bar.contains(generate_data(1)))
         ).first()
         eq_(id_, 1)
-        eq_(value, gen(1))
+        eq_(value, generate_data(1))
 
     @testing.combinations(
         (set,), (list,), (lambda elem: (x for x in elem),), argnames="struct"