]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
support subscript for hstore
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 31 Oct 2025 18:08:05 +0000 (14:08 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Nov 2025 01:11:49 +0000 (21:11 -0400)
Added support for PostgreSQL 14+ HSTORE subscripting syntax. When connected
to PostgreSQL 14 or later, HSTORE columns now automatically use the native
subscript notation ``hstore_col['key']`` instead of the arrow operator
``hstore_col -> 'key'`` for both read and write operations. This provides
better compatibility with PostgreSQL's native HSTORE subscripting feature
while maintaining backward compatibility with older PostgreSQL versions.

as part of this change we add a new parameter to custom_op "visit_name"
which allows a custom op to refer to a specific visit method in a
dialect's compiler.

Fixes: #12948
Change-Id: Id98d333fe78e31d9c7679cb2902f1c7e458d6e11

doc/build/changelog/unreleased_21/12948.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/operators.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/operators.py
test/dialect/postgresql/test_compiler.py
test/dialect/postgresql/test_query.py
test/dialect/postgresql/test_types.py
test/sql/test_operators.py

diff --git a/doc/build/changelog/unreleased_21/12948.rst b/doc/build/changelog/unreleased_21/12948.rst
new file mode 100644 (file)
index 0000000..cd65877
--- /dev/null
@@ -0,0 +1,16 @@
+.. change::
+    :tags: usecase, postgresql
+    :tickets: 12948
+
+    Added support for PostgreSQL 14+ HSTORE subscripting syntax. When connected
+    to PostgreSQL 14 or later, HSTORE columns now automatically use the native
+    subscript notation ``hstore_col['key']`` instead of the arrow operator
+    ``hstore_col -> 'key'`` for both read and write operations. This provides
+    better compatibility with PostgreSQL's native HSTORE subscripting feature
+    while maintaining backward compatibility with older PostgreSQL versions.
+
+    .. warning:: Indexes in existing PostgreSQL databases which were indexed
+       on an HSTORE subscript expression would need to be updated in order to
+       match the new SQL syntax.
+
+
index 2e4b10c5998612bbe9a2352a3878b1ead3e6b640..e973d28feffecab0893a3b02defac5ccfaf367db 100644 (file)
@@ -2070,6 +2070,23 @@ class PGCompiler(compiler.SQLCompiler):
             binary, " #> " if not _cast_applied else " #>> ", **kw
         )
 
+    def visit_hstore_getitem_op_binary(self, binary, operator, **kw):
+        kw["eager_grouping"] = True
+
+        if self.dialect._supports_jsonb_subscripting:
+            # use subscript notation: col['key'] instead of col -> 'key'
+            # For function calls, wrap in parentheses: (func())[key]
+            left_str = self.process(binary.left, **kw)
+            if isinstance(binary.left, sql.functions.FunctionElement):
+                left_str = f"({left_str})"
+            return "%s[%s]" % (
+                left_str,
+                self.process(binary.right, **kw),
+            )
+        else:
+            # Fall back to arrow notation for older versions
+            return self._generate_generic_binary(binary, " -> ", **kw)
+
     def visit_getitem_binary(self, binary, operator, **kw):
         return "%s[%s]" % (
             self.process(binary.left, **kw),
index ebcafcba991ecc41c14ca1718641db9a9f7fa90e..2e761139a637a5d11fd7f28f6b423a2b4c9850fd 100644 (file)
@@ -126,4 +126,5 @@ GETITEM = operators.custom_op(
     precedence=_getitem_precedence,
     natural_self_precedent=True,
     eager_grouping=True,
+    visit_name="hstore_getitem",
 )
index 9d2a9e3bd1204cc443b781b8d0f67151c6622eb7..8e50da46b34d7850335baa322868a59ec2656140 100644 (file)
@@ -3163,6 +3163,10 @@ class SQLCompiler(Compiled):
         )
         return getattr(self, attrname, None)
 
+    def _get_custom_operator_dispatch(self, operator_, qualifier1):
+        attrname = "visit_%s_op_%s" % (operator_.visit_name, qualifier1)
+        return getattr(self, attrname, None)
+
     def visit_unary(
         self, unary, add_to_result_map=None, result_map_targets=(), **kw
     ):
@@ -3527,6 +3531,11 @@ class SQLCompiler(Compiled):
             )
 
     def visit_custom_op_binary(self, element, operator, **kw):
+        if operator.visit_name:
+            disp = self._get_custom_operator_dispatch(operator, "binary")
+            if disp:
+                return disp(element, operator, **kw)
+
         kw["eager_grouping"] = operator.eager_grouping
         return self._generate_generic_binary(
             element,
@@ -3535,11 +3544,21 @@ class SQLCompiler(Compiled):
         )
 
     def visit_custom_op_unary_operator(self, element, operator, **kw):
+        if operator.visit_name:
+            disp = self._get_custom_operator_dispatch(operator, "unary")
+            if disp:
+                return disp(element, operator, **kw)
+
         return self._generate_generic_unary_operator(
             element, self.escape_literal_column(operator.opstring) + " ", **kw
         )
 
     def visit_custom_op_unary_modifier(self, element, operator, **kw):
+        if operator.visit_name:
+            disp = self._get_custom_operator_dispatch(operator, "unary")
+            if disp:
+                return disp(element, operator, **kw)
+
         return self._generate_generic_unary_modifier(
             element, " " + self.escape_literal_column(operator.opstring), **kw
         )
index 39614b917876dccd9c5d41a84f5fe694b89c8a83..cd4ada2a54792f85d2c8ede8fbb2bbb673d40d28 100644 (file)
@@ -864,6 +864,7 @@ class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly):
             return_type: _TypeEngineArgument[_OPT],
             python_impl: Optional[Callable[..., Any]] = None,
             operator_class: OperatorClass = ...,
+            visit_name: Optional[str] = ...,
         ) -> Callable[[Any], BinaryExpression[_OPT]]: ...
 
         @overload
@@ -875,6 +876,7 @@ class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly):
             return_type: Optional[_TypeEngineArgument[Any]] = ...,
             python_impl: Optional[Callable[..., Any]] = ...,
             operator_class: OperatorClass = ...,
+            visit_name: Optional[str] = ...,
         ) -> Callable[[Any], BinaryExpression[Any]]: ...
 
         def op(
@@ -885,6 +887,7 @@ class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly):
             return_type: Optional[_TypeEngineArgument[Any]] = None,
             python_impl: Optional[Callable[..., Any]] = None,
             operator_class: OperatorClass = OperatorClass.BASE,
+            visit_name: Optional[str] = None,
         ) -> Callable[[Any], BinaryExpression[Any]]: ...
 
         def bool_op(
index 9d4d86a341a54388de1900f1b87fba1a2a65c8a5..3bd6427e79d21d11573fb6460c36dc8230aa5dc7 100644 (file)
@@ -278,6 +278,7 @@ class Operators:
         ] = None,
         python_impl: Optional[Callable[..., Any]] = None,
         operator_class: OperatorClass = OperatorClass.BASE,
+        visit_name: Optional[str] = None,
     ) -> Callable[[Any], Operators]:
         """Produce a generic operator function.
 
@@ -369,6 +370,17 @@ class Operators:
 
             :ref:`relationship_custom_operator`
 
+        :param visit_name: string name indicating a series of methods that
+         maybe be implemented by a :class:`.Dialect`, specifically using its
+         :class:`.SQLCompiler` implementation.  The naming scheme is
+         ``visit_<visit_name>_op_[binary|unary]``; e.g. using the visit name
+         ``hstore`` means that a binary expression using the operator will
+         attempt to locate a method ``visit_hstore_op_binary()`` on the
+         target dialect's compiler class, which can then provide a compilation
+         string for the full binary expression.
+
+         .. versionadded:: 2.1
+
         """
         operator = custom_op(
             opstring,
@@ -377,6 +389,7 @@ class Operators:
             return_type=return_type,
             python_impl=python_impl,
             operator_class=operator_class,
+            visit_name=visit_name,
         )
 
         def against(other: Any) -> Operators:
@@ -488,6 +501,7 @@ class custom_op(OperatorType, Generic[_T]):
         "return_type",
         "python_impl",
         "operator_class",
+        "visit_name",
     )
 
     def __init__(
@@ -503,7 +517,13 @@ class custom_op(OperatorType, Generic[_T]):
         eager_grouping: bool = False,
         python_impl: Optional[Callable[..., Any]] = None,
         operator_class: OperatorClass = OperatorClass.BASE,
+        visit_name: Optional[str] = None,
     ):
+        """Create a new :class:`.custom_op`.
+
+        See :meth:`.Operators.op` for parameter information.
+
+        """
         self.opstring = opstring
         self.precedence = precedence
         self.is_comparison = is_comparison
@@ -514,6 +534,7 @@ class custom_op(OperatorType, Generic[_T]):
         )
         self.python_impl = python_impl
         self.operator_class = operator_class
+        self.visit_name = visit_name
 
     def __eq__(self, other: Any) -> bool:
         return (
index 175b099940b714f6119fffd718d755288f265036..ed1bece524c44534a8f52d3b330cf0232d1cda5e 100644 (file)
@@ -50,6 +50,7 @@ from sqlalchemy.dialects.postgresql import array_agg as pg_array_agg
 from sqlalchemy.dialects.postgresql import distinct_on
 from sqlalchemy.dialects.postgresql import DOMAIN
 from sqlalchemy.dialects.postgresql import ExcludeConstraint
+from sqlalchemy.dialects.postgresql import HSTORE
 from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy.dialects.postgresql import JSON
 from sqlalchemy.dialects.postgresql import JSONB
@@ -2903,6 +2904,40 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "UPDATE data SET x -> %(x_1)s=(data.x -> %(x_2)s)",
         )
 
+    @testing.variation("pgversion", ["pg14", "pg13"])
+    def test_hstore_subscripting(self, pgversion):
+        """test #12948 - PostgreSQL 14+ HSTORE subscripting syntax"""
+        data = table("data", column("id", Integer), column("h", HSTORE))
+
+        dialect = postgresql.dialect()
+
+        if pgversion.pg13:
+            dialect._supports_jsonb_subscripting = False
+
+        # Test SELECT with HSTORE indexing
+        stmt = select(data.c.h["key"])
+        self.assert_compile(
+            stmt,
+            (
+                "SELECT data.h[%(h_1)s] AS anon_1 FROM data"
+                if pgversion.pg14
+                else "SELECT data.h -> %(h_1)s AS anon_1 FROM data"
+            ),
+            dialect=dialect,
+        )
+
+        # Test UPDATE with HSTORE indexing (the original issue case)
+        stmt = update(data).values({data.c.h["new_key"]: data.c.h["old_key"]})
+        self.assert_compile(
+            stmt,
+            (
+                "UPDATE data SET h[%(h_1)s]=(data.h[%(h_2)s])"
+                if pgversion.pg14
+                else "UPDATE data SET h -> %(h_1)s=(data.h -> %(h_2)s)"
+            ),
+            dialect=dialect,
+        )
+
     def test_jsonb_functions_use_parentheses_with_subscripting(self):
         """test #12778 - JSONB functions are parenthesized with [] syntax"""
         data = table("data", column("id", Integer), column("x", JSONB))
index fc9f7f79188f8111864106c0ecface23b66d76f2..3e392e8fd239bf132353df1dac562011cf47a919 100644 (file)
@@ -29,6 +29,7 @@ from sqlalchemy import tuple_
 from sqlalchemy import Uuid
 from sqlalchemy import values
 from sqlalchemy.dialects import postgresql
+from sqlalchemy.dialects.postgresql import HSTORE
 from sqlalchemy.dialects.postgresql import JSONB
 from sqlalchemy.dialects.postgresql import REGCONFIG
 from sqlalchemy.sql.expression import type_coerce
@@ -2001,3 +2002,150 @@ class JSONUpdateTest(fixtures.TablesTest):
             row.jb,
             {"tags": ["python", "postgresql", "postgres"], "priority": "high"},
         )
+
+
+class HstoreUpdateTest(fixtures.TablesTest):
+    """round trip tests related to using HSTORE in UPDATE statements
+    with PG-specific features
+
+    """
+
+    __only_on__ = "postgresql"
+    __backend__ = True
+    __requires__ = ("native_hstore",)
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "t",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("h", HSTORE),
+        )
+
+    @classmethod
+    def insert_data(cls, connection):
+        connection.execute(
+            cls.tables["t"].insert(),
+            [
+                {"id": 1, "h": {"k1": "v1", "k2": "v2"}},
+                {"id": 2, "h": {"k3": "v3", "k4": "v4"}},
+            ],
+        )
+
+    @testing.only_on("postgresql>=14")
+    def test_hstore_element_update_basic(self, connection):
+        """Test updating individual HSTORE elements with subscript syntax
+
+        test #12948
+
+        """
+        t = self.tables["t"]
+
+        # Insert test data with HSTORE
+        connection.execute(
+            t.insert(),
+            [
+                {
+                    "id": 10,
+                    "h": {"name": "Alice", "status": "active"},
+                },
+                {
+                    "id": 11,
+                    "h": {"name": "Bob", "status": "inactive"},
+                },
+            ],
+        )
+
+        # Update specific elements using HSTORE subscript syntax
+        # This tests the new HSTORE subscripting feature from issue #12948
+        connection.execute(
+            t.update()
+            .values({t.c.h["name"]: "Alice Updated"})
+            .where(t.c.id == 10)
+        )
+
+        connection.execute(
+            t.update().values({t.c.h["status"]: "active"}).where(t.c.id == 11)
+        )
+
+        results = connection.execute(
+            t.select().where(t.c.id.in_([10, 11])).order_by(t.c.id)
+        )
+
+        eq_(
+            [row.h for row in results],
+            [
+                {"name": "Alice Updated", "status": "active"},
+                {"name": "Bob", "status": "active"},
+            ],
+        )
+
+    @testing.only_on("postgresql>=14")
+    def test_hstore_element_update_multiple_keys(self, connection):
+        """Test updating multiple HSTORE elements in a single statement
+
+        test #12948
+
+        """
+        t = self.tables["t"]
+
+        connection.execute(
+            t.insert(),
+            {
+                "id": 20,
+                "h": {
+                    "config_theme": "dark",
+                    "config_lang": "en",
+                    "version": "1",
+                },
+            },
+        )
+
+        # Update multiple elements at once
+        connection.execute(
+            t.update()
+            .values({t.c.h["config_theme"]: "light", t.c.h["version"]: "2"})
+            .where(t.c.id == 20)
+        )
+
+        # Verify the updates
+        row = connection.execute(t.select().where(t.c.id == 20)).one()
+
+        eq_(
+            row.h,
+            {"config_theme": "light", "config_lang": "en", "version": "2"},
+        )
+
+    @testing.only_on("postgresql>=14")
+    def test_hstore_element_update_new_key(self, connection):
+        """Test adding new keys to HSTORE using subscript syntax
+
+        test #12948
+
+        """
+        t = self.tables["t"]
+
+        # Insert test data
+        connection.execute(
+            t.insert(),
+            {
+                "id": 30,
+                "h": {"existing_key": "existing_value"},
+            },
+        )
+
+        # Add a new key using subscript syntax
+        connection.execute(
+            t.update()
+            .values({t.c.h["new_key"]: "new_value"})
+            .where(t.c.id == 30)
+        )
+
+        # Verify the update
+        row = connection.execute(t.select().where(t.c.id == 30)).fetchone()
+
+        eq_(
+            row.h,
+            {"existing_key": "existing_value", "new_key": "new_value"},
+        )
index 6cbb3bf481981254c75f1dcfd4232e4b54a08402..6ef25fe363355e10ab9bcc001b196d0f78ccd8db 100644 (file)
@@ -4253,21 +4253,17 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
     def test_where_getitem(self):
         self._test_where(
             self.hashcol["bar"] == None,  # noqa
-            "(test_table.hash -> %(hash_1)s) IS NULL",
+            "test_table.hash[%(hash_1)s] IS NULL",
         )
 
     def test_where_getitem_any(self):
         self._test_where(
             self.hashcol["bar"] == any_(array(["foo"])),  # noqa
-            "(test_table.hash -> %(hash_1)s) = ANY (ARRAY[%(param_1)s])",
+            "test_table.hash[%(hash_1)s] = ANY (ARRAY[%(param_1)s])",
         )
 
+    # Test combinations that don't use subscript operator
     @testing.combinations(
-        (
-            lambda self: self.hashcol["foo"],
-            "test_table.hash -> %(hash_1)s AS anon_1",
-            True,
-        ),
         (
             lambda self: self.hashcol.delete("foo"),
             "delete(test_table.hash, %(delete_2)s) AS delete_1",
@@ -4297,29 +4293,6 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
             ),
             True,
         ),
-        (
-            lambda self: hstore("foo", "3")["foo"],
-            "hstore(%(hstore_1)s, %(hstore_2)s) -> %(hstore_3)s AS anon_1",
-            False,
-        ),
-        (
-            lambda self: hstore(
-                postgresql.array(["1", "2"]), postgresql.array(["3", None])
-            )["1"],
-            (
-                "hstore(ARRAY[%(param_1)s, %(param_2)s], "
-                "ARRAY[%(param_3)s, NULL]) -> %(hstore_1)s AS anon_1"
-            ),
-            False,
-        ),
-        (
-            lambda self: hstore(postgresql.array(["1", "2", "3", None]))["3"],
-            (
-                "hstore(ARRAY[%(param_1)s, %(param_2)s, %(param_3)s, NULL]) "
-                "-> %(hstore_1)s AS anon_1"
-            ),
-            False,
-        ),
         (
             lambda self: self.hashcol.concat(
                 hstore(cast(self.test_table.c.id, Text), "3")
@@ -4335,16 +4308,6 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
             "hstore(%(hstore_1)s, %(hstore_2)s) || test_table.hash AS anon_1",
             True,
         ),
-        (
-            lambda self: (self.hashcol + self.hashcol)["foo"],
-            "(test_table.hash || test_table.hash) -> %(param_1)s AS anon_1",
-            True,
-        ),
-        (
-            lambda self: self.hashcol["foo"] != None,  # noqa
-            "(test_table.hash -> %(hash_1)s) IS NOT NULL AS anon_1",
-            True,
-        ),
         (
             # hide from 2to3
             lambda self: getattr(self.hashcol, "keys")(),
@@ -4375,6 +4338,55 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
             ("SELECT %s" + (" FROM test_table" if from_ else "")) % expected,
         )
 
+    # Test combinations that use subscript operator (PG 14+ uses [] syntax)
+    @testing.combinations(
+        (
+            lambda self: self.hashcol["foo"],
+            "test_table.hash[%(hash_1)s] AS anon_1",
+            True,
+        ),
+        (
+            lambda self: hstore("foo", "3")["foo"],
+            "(hstore(%(hstore_1)s, %(hstore_2)s))[%(hstore_3)s] AS anon_1",
+            False,
+        ),
+        (
+            lambda self: hstore(
+                postgresql.array(["1", "2"]), postgresql.array(["3", None])
+            )["1"],
+            (
+                "(hstore(ARRAY[%(param_1)s, %(param_2)s], "
+                "ARRAY[%(param_3)s, NULL]))[%(hstore_1)s] AS anon_1"
+            ),
+            False,
+        ),
+        (
+            lambda self: hstore(postgresql.array(["1", "2", "3", None]))["3"],
+            (
+                "(hstore(ARRAY[%(param_1)s, %(param_2)s, %(param_3)s, NULL]))"
+                "[%(hstore_1)s] AS anon_1"
+            ),
+            False,
+        ),
+        (
+            lambda self: (self.hashcol + self.hashcol)["foo"],
+            "(test_table.hash || test_table.hash)[%(param_1)s] AS anon_1",
+            True,
+        ),
+        (
+            lambda self: self.hashcol["foo"] != None,  # noqa
+            "test_table.hash[%(hash_1)s] IS NOT NULL AS anon_1",
+            True,
+        ),
+    )
+    def test_cols_subscript(self, colclause_fn, expected, from_):
+        colclause = colclause_fn(self)
+        stmt = select(colclause)
+        self.assert_compile(
+            stmt,
+            ("SELECT %s" + (" FROM test_table" if from_ else "")) % expected,
+        )
+
 
 class HStoreRoundTripTest(fixtures.TablesTest):
     __requires__ = ("hstore",)
index 8de22b89dbc6b4bf68671d183075a576198d037f..007563f367f01da08786b79ed2601631e333a550 100644 (file)
@@ -4655,6 +4655,77 @@ class CustomOpTest(fixtures.TestBase):
             col == "test"
 
 
+class CustomOpDialectCompileTest(
+    testing.AssertsCompiledSQL, fixtures.TestBase
+):
+    """test new custom op dispatch feature added as part of #12948"""
+
+    @testing.fixture
+    def dialect_fixture(self):
+
+        class MyCompiler(compiler.SQLCompiler):
+            def visit_myop_op_binary(self, binary, operator, **kw):
+                return "|%s| ->%s<-" % (
+                    self.process(binary.left, **kw),
+                    self.process(binary.right, **kw),
+                )
+
+            def visit_myop_op_unary(self, unary, operator, **kw):
+                if operator is unary.modifier:
+                    return "%s->|" % (self.process(unary.element, **kw))
+                elif operator is unary.operator:
+                    return "|->%s" % (self.process(unary.element, **kw))
+
+        class MyDialect(default.DefaultDialect):
+            statement_compiler = MyCompiler
+
+        myop = operators.custom_op(
+            "---",
+            precedence=15,
+            natural_self_precedent=True,
+            eager_grouping=True,
+            visit_name="myop",
+        )
+        return MyDialect, myop
+
+    @testing.variation("dialect", ["default", "custom"])
+    def test_binary_override(self, dialect_fixture, dialect):
+        MyDialect, myop = dialect_fixture
+
+        if dialect.default:
+            self.assert_compile(
+                myop(column("q", String), column("y", String)), "q --- y"
+            )
+        elif dialect.custom:
+            self.assert_compile(
+                myop(column("q", String), column("y", String)),
+                "|q| ->y<-",
+                dialect=MyDialect(),
+            )
+
+    @testing.variation("dialect", ["default", "custom"])
+    def test_unary_modifier_override(self, dialect_fixture, dialect):
+        MyDialect, myop = dialect_fixture
+
+        unary = UnaryExpression(column("zqr"), modifier=myop, type_=Numeric)
+
+        if dialect.default:
+            self.assert_compile(unary, "zqr ---")
+        elif dialect.custom:
+            self.assert_compile(unary, "zqr->|", dialect=MyDialect())
+
+    @testing.variation("dialect", ["default", "custom"])
+    def test_unary_operator_override(self, dialect_fixture, dialect):
+        MyDialect, myop = dialect_fixture
+
+        unary = UnaryExpression(column("zqr"), operator=myop, type_=Numeric)
+
+        if dialect.default:
+            self.assert_compile(unary, "--- zqr")
+        elif dialect.custom:
+            self.assert_compile(unary, "|->zqr", dialect=MyDialect())
+
+
 class TupleTypingTest(fixtures.TestBase):
     def _assert_types(self, expr):
         eq_(expr[0]._type_affinity, Integer)