From 547f5f8b43542b7da97e4b43ac759ba1067c1ce9 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 5 Sep 2025 09:29:34 -0400 Subject: [PATCH] interpret NULL in PG enum array values Fixed issue where selecting an enum array column containing NULL values would fail to parse properly in the PostgreSQL dialect. The :func:`._split_enum_values` function now correctly handles NULL entries by converting them to Python ``None`` values. Fixes: #12847 Change-Id: I39d10bc1be6b458da7e5d3f4b740f8faafd0adc5 (cherry picked from commit 9eb35c6664094a8e2b7ca1a0794f3cfd65cd46cf) --- doc/build/changelog/unreleased_20/12847.rst | 8 ++ lib/sqlalchemy/dialects/postgresql/array.py | 17 ++- test/dialect/postgresql/test_types.py | 140 +++++++++++++++----- 3 files changed, 125 insertions(+), 40 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/12847.rst diff --git a/doc/build/changelog/unreleased_20/12847.rst b/doc/build/changelog/unreleased_20/12847.rst new file mode 100644 index 0000000000..bba7849d3e --- /dev/null +++ b/doc/build/changelog/unreleased_20/12847.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, postgresql + :tickets: 12847 + + Fixed issue where selecting an enum array column containing NULL values + would fail to parse properly in the PostgreSQL dialect. The + :func:`._split_enum_values` function now correctly handles NULL entries by + converting them to Python ``None`` values. diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index 96f6dc21a2..7339d5c7d8 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -464,7 +464,7 @@ class ARRAY(sqltypes.ARRAY[_T]): super_rp = process pattern = re.compile(r"^{(.*)}$") - def handle_raw_string(value: str) -> list[str]: + def handle_raw_string(value: str) -> Sequence[Optional[str]]: inner = pattern.match(value).group(1) # type: ignore[union-attr] # noqa: E501 return _split_enum_values(inner) @@ -485,10 +485,13 @@ class ARRAY(sqltypes.ARRAY[_T]): return process -def _split_enum_values(array_string: str) -> list[str]: +def _split_enum_values(array_string: str) -> Sequence[Optional[str]]: if '"' not in array_string: # no escape char is present so it can just split on the comma - return array_string.split(",") if array_string else [] + return [ + r if r != "NULL" else None + for r in (array_string.split(",") if array_string else []) + ] # handles quoted strings from: # r'abc,"quoted","also\\\\quoted", "quoted, comma", "esc \" quot", qpr' @@ -505,5 +508,11 @@ def _split_enum_values(array_string: str) -> list[str]: elif in_quotes: result.append(tok.replace("_$ESC_QUOTE$_", '"')) else: - result.extend(re.findall(r"([^\s,]+),?", tok)) + # interpret NULL (without quotes!) as None + result.extend( + [ + r if r != "NULL" else None + for r in re.findall(r"([^\s,]+),?", tok) + ] + ) return result diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 4613ebf32c..6d1bdcfa38 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -1,6 +1,7 @@ import datetime import decimal from enum import Enum as _PY_Enum +import functools from ipaddress import IPv4Address from ipaddress import IPv4Network from ipaddress import IPv6Address @@ -2641,7 +2642,11 @@ class ArrayRoundTripTest: t.drop(connection) eq_(inspect(connection).get_enums(), []) - def _type_combinations(exclude_json=False, exclude_empty_lists=False): + def _type_combinations( + exclude_json=False, + exclude_empty_lists=False, + exclude_arrays_with_none=False, + ): def str_values(x): return ["one", "two: %s" % x, "three", "four", "five"] @@ -2693,22 +2698,25 @@ class ArrayRoundTripTest: def __ne__(self, other): return not self.__eq__(other) + simple_enum = ["one", "two", "three", "four", "five", "six"] difficult_enum = [ "Value", "With space", "With,comma", + "NULL", 'With"quote', "With\\escape", """Various!@#$%^*()"'\\][{};:.<>|_+~chars""", ] - def make_difficult_enum(cls_, native): - return cls_( - *difficult_enum, name="difficult_enum", native_enum=native - ) + def make_enum(cls_, members, native): + return cls_(*members, name="difficult_enum", native_enum=native) - def difficult_enum_values(x): - return [v for i, v in enumerate(difficult_enum) if i != x - 1] + def make_enum_values(members, x, *, include_none=False): + arr = [v for i, v in enumerate(members) if i != x - 1] + if include_none: + arr.insert(2, None) + return arr elements = [ (sqltypes.Integer, lambda x: [1, x, 3, 4, 5]), @@ -2801,19 +2809,71 @@ class ArrayRoundTripTest: enum_values, ), ( - make_difficult_enum(sqltypes.Enum, native=True), - difficult_enum_values, + make_enum(sqltypes.Enum, difficult_enum, native=True), + functools.partial(make_enum_values, difficult_enum), ), ( - make_difficult_enum(sqltypes.Enum, native=False), - difficult_enum_values, + make_enum(sqltypes.Enum, difficult_enum, native=False), + functools.partial(make_enum_values, difficult_enum), ), ( - make_difficult_enum(postgresql.ENUM, native=True), - difficult_enum_values, + make_enum(postgresql.ENUM, difficult_enum, native=True), + functools.partial(make_enum_values, difficult_enum), ), ] + if not exclude_arrays_with_none: + elements.extend( + [ + ( + # unquoted ENUM values including NULL in the data + make_enum(sqltypes.Enum, simple_enum, native=True), + functools.partial( + make_enum_values, simple_enum, include_none=True + ), + ), + ( + # unquoted ENUM values including NULL in the data + make_enum(sqltypes.Enum, simple_enum, native=False), + functools.partial( + make_enum_values, simple_enum, include_none=True + ), + ), + ( + # unquoted ENUM values including NULL in the data + make_enum(postgresql.ENUM, simple_enum, native=True), + functools.partial( + make_enum_values, simple_enum, include_none=True + ), + ), + ( + # quoted ENUM values, including both + # quoted "NULL" and real NULL in the data + make_enum(sqltypes.Enum, difficult_enum, native=True), + functools.partial( + make_enum_values, difficult_enum, include_none=True + ), + ), + ( + # quoted ENUM values, including both + # quoted "NULL" and real NULL in the data + make_enum(sqltypes.Enum, difficult_enum, native=False), + functools.partial( + make_enum_values, difficult_enum, include_none=True + ), + ), + ( + # quoted ENUM values, including both + # quoted "NULL" and real NULL in the data + make_enum( + postgresql.ENUM, difficult_enum, native=True + ), + functools.partial( + make_enum_values, difficult_enum, include_none=True + ), + ), + ] + ) if not exclude_empty_lists: elements.extend( [ @@ -2853,7 +2913,7 @@ class ArrayRoundTripTest: elements[i] = elem return testing.combinations_list( - elements, argnames="type_,gen", id_="na" + elements, argnames="type_,generate_data", id_="na" ) @classmethod @@ -2879,10 +2939,13 @@ class ArrayRoundTripTest: meta.create_all(connection) - def go(gen): + def go(generate_data): connection.execute( table.insert(), - [{"id": 1, "bar": gen(1)}, {"id": 2, "bar": gen(2)}], + [ + {"id": 1, "bar": generate_data(1)}, + {"id": 2, "bar": generate_data(2)}, + ], ) return table @@ -2890,23 +2953,23 @@ class ArrayRoundTripTest: @_type_combinations() def test_type_specific_value_select( - self, type_specific_fixture, connection, type_, gen + self, type_specific_fixture, connection, type_, generate_data ): - table = type_specific_fixture(gen) + table = type_specific_fixture(generate_data) rows = connection.execute( select(table.c.bar).order_by(table.c.id) ).all() - eq_(rows, [(gen(1),), (gen(2),)]) + eq_(rows, [(generate_data(1),), (generate_data(2),)]) @_type_combinations() def test_type_specific_value_update( - self, type_specific_fixture, connection, type_, gen + self, type_specific_fixture, connection, type_, generate_data ): - table = type_specific_fixture(gen) + table = type_specific_fixture(generate_data) - new_gen = gen(3) + new_gen = generate_data(3) connection.execute( table.update().where(table.c.id == 2).values(bar=new_gen) ) @@ -2918,11 +2981,11 @@ class ArrayRoundTripTest: @_type_combinations(exclude_empty_lists=True) def test_type_specific_slice_update( - self, type_specific_fixture, connection, type_, gen + self, type_specific_fixture, connection, type_, generate_data ): - table = type_specific_fixture(gen) + table = type_specific_fixture(generate_data) - new_gen = gen(3) + new_gen = generate_data(3) if not table.c.bar.type._variant_mapping: # this is not likely to occur to users but we need to just @@ -2938,18 +3001,18 @@ class ArrayRoundTripTest: select(table.c.bar).order_by(table.c.id) ).all() - sliced_gen = gen(2) + sliced_gen = generate_data(2) sliced_gen[0:3] = new_gen[1:4] - eq_(rows, [(gen(1),), (sliced_gen,)]) + eq_(rows, [(generate_data(1),), (sliced_gen,)]) @_type_combinations(exclude_json=True, exclude_empty_lists=True) def test_type_specific_value_delete( - self, type_specific_fixture, connection, type_, gen + self, type_specific_fixture, connection, type_, generate_data ): - table = type_specific_fixture(gen) + table = type_specific_fixture(generate_data) - new_gen = gen(2) + new_gen = generate_data(2) connection.execute(table.delete().where(table.c.bar == new_gen)) @@ -2967,22 +3030,27 @@ class PGArrayRoundTripTest( ): ARRAY = postgresql.ARRAY - @ArrayRoundTripTest._cls_type_combinations(exclude_json=True) + @ArrayRoundTripTest._cls_type_combinations( + exclude_json=True, exclude_arrays_with_none=True + ) def test_type_specific_contains( - self, type_specific_fixture, connection, type_, gen + self, type_specific_fixture, connection, type_, generate_data ): - table = type_specific_fixture(gen) + table = type_specific_fixture(generate_data) connection.execute( table.insert(), - [{"id": 1, "bar": gen(1)}, {"id": 2, "bar": gen(2)}], + [ + {"id": 1, "bar": generate_data(1)}, + {"id": 2, "bar": generate_data(2)}, + ], ) id_, value = connection.execute( - select(table).where(table.c.bar.contains(gen(1))) + select(table).where(table.c.bar.contains(generate_data(1))) ).first() eq_(id_, 1) - eq_(value, gen(1)) + eq_(value, generate_data(1)) @testing.combinations( (set,), (list,), (lambda elem: (x for x in elem),), argnames="struct" -- 2.47.3