]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Render labels in autogenerate index
authorFederico Caselli <cfederico87@gmail.com>
Sat, 8 Feb 2025 11:57:31 +0000 (12:57 +0100)
committerFederico Caselli <cfederico87@gmail.com>
Sun, 9 Feb 2025 10:22:50 +0000 (11:22 +0100)
Index autogenerate will now render labels for expressions
that use them. This is useful when applying operator classes
in PostgreSQL that can be keyed on the label name.

Fixes: #1603
Change-Id: I187a944a5021e643e264ee1ec97807f6573e5f2f

alembic/autogenerate/render.py
alembic/ddl/postgresql.py
alembic/op.pyi
alembic/operations/base.py
alembic/operations/ops.py
docs/build/unreleased/1603.rst [new file with mode: 0644]
tests/test_autogen_render.py
tests/test_postgresql.py

index 6ebfbf9f3ffb0979be1ea0a8f6ae9a0cdf523587..7357c1d521bc1068d767b5e66827bab3f9a65d1f 100644 (file)
@@ -19,6 +19,7 @@ from sqlalchemy import schema as sa_schema
 from sqlalchemy import sql
 from sqlalchemy import types as sqltypes
 from sqlalchemy.sql.elements import conv
+from sqlalchemy.sql.elements import Label
 from sqlalchemy.sql.elements import quoted_name
 
 from .. import util
@@ -584,23 +585,28 @@ def _render_potential_expr(
     value: Any,
     autogen_context: AutogenContext,
     *,
-    wrap_in_text: bool = True,
+    wrap_in_element: bool = True,
     is_server_default: bool = False,
     is_index: bool = False,
 ) -> str:
     if isinstance(value, sql.ClauseElement):
-        if wrap_in_text:
-            template = "%(prefix)stext(%(sql)r)"
+        sql_text = autogen_context.migration_context.impl.render_ddl_sql_expr(
+            value, is_server_default=is_server_default, is_index=is_index
+        )
+        if wrap_in_element:
+            prefix = _sqlalchemy_autogenerate_prefix(autogen_context)
+            element = "literal_column" if is_index else "text"
+            value_str = f"{prefix}{element}({sql_text!r})"
+            if (
+                is_index
+                and isinstance(value, Label)
+                and type(value.name) is str
+            ):
+                return value_str + f".label({value.name!r})"
+            else:
+                return value_str
         else:
-            template = "%(sql)r"
-
-        return template % {
-            "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
-            "sql": autogen_context.migration_context.impl.render_ddl_sql_expr(
-                value, is_server_default=is_server_default, is_index=is_index
-            ),
-        }
-
+            return repr(sql_text)
     else:
         return repr(value)
 
@@ -787,7 +793,7 @@ def _render_computed(
     computed: Computed, autogen_context: AutogenContext
 ) -> str:
     text = _render_potential_expr(
-        computed.sqltext, autogen_context, wrap_in_text=False
+        computed.sqltext, autogen_context, wrap_in_element=False
     )
 
     kwargs = {}
@@ -1101,7 +1107,7 @@ def _render_check_constraint(
             else ""
         ),
         "sqltext": _render_potential_expr(
-            constraint.sqltext, autogen_context, wrap_in_text=False
+            constraint.sqltext, autogen_context, wrap_in_element=False
         ),
     }
 
index 60aa15366c2db1c71f24602f38202c914caa5ffe..2623308fdfc7bc84656851be867e261523a4e599 100644 (file)
@@ -846,5 +846,5 @@ def _render_potential_column(
         return render._render_potential_expr(
             value,
             autogen_context,
-            wrap_in_text=isinstance(value, (TextClause, FunctionElement)),
+            wrap_in_element=isinstance(value, (TextClause, FunctionElement)),
         )
index 920444696ecf511e99216049072c8d96025f1610..d86bef4680ddc9bfe01ff586b5d9d7b8a8d0b184 100644 (file)
@@ -27,7 +27,6 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.elements import conv
     from sqlalchemy.sql.elements import TextClause
     from sqlalchemy.sql.expression import TableClause
-    from sqlalchemy.sql.functions import Function
     from sqlalchemy.sql.schema import Column
     from sqlalchemy.sql.schema import Computed
     from sqlalchemy.sql.schema import Identity
@@ -650,7 +649,7 @@ def create_foreign_key(
 def create_index(
     index_name: Optional[str],
     table_name: str,
-    columns: Sequence[Union[str, TextClause, Function[Any]]],
+    columns: Sequence[Union[str, TextClause, ColumnElement[Any]]],
     *,
     schema: Optional[str] = None,
     unique: bool = False,
index 9b52fa6f29e2acdeb63c7970f12f607419a3ae0b..456d1c75bb39eca98417bd491c5cdd50384ef440 100644 (file)
@@ -43,7 +43,6 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.expression import ColumnElement
     from sqlalchemy.sql.expression import TableClause
     from sqlalchemy.sql.expression import TextClause
-    from sqlalchemy.sql.functions import Function
     from sqlalchemy.sql.schema import Column
     from sqlalchemy.sql.schema import Computed
     from sqlalchemy.sql.schema import Identity
@@ -1074,7 +1073,7 @@ class Operations(AbstractOperations):
             self,
             index_name: Optional[str],
             table_name: str,
-            columns: Sequence[Union[str, TextClause, Function[Any]]],
+            columns: Sequence[Union[str, TextClause, ColumnElement[Any]]],
             *,
             schema: Optional[str] = None,
             unique: bool = False,
index 60b856a8f7664e436164450aa715620294d10d1b..bb4d825b1460adaf66e3bfb711e511b1f3f3877e 100644 (file)
@@ -35,7 +35,6 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.elements import conv
     from sqlalchemy.sql.elements import quoted_name
     from sqlalchemy.sql.elements import TextClause
-    from sqlalchemy.sql.functions import Function
     from sqlalchemy.sql.schema import CheckConstraint
     from sqlalchemy.sql.schema import Column
     from sqlalchemy.sql.schema import Computed
@@ -933,7 +932,7 @@ class CreateIndexOp(MigrateOperation):
         operations: Operations,
         index_name: Optional[str],
         table_name: str,
-        columns: Sequence[Union[str, TextClause, Function[Any]]],
+        columns: Sequence[Union[str, TextClause, ColumnElement[Any]]],
         *,
         schema: Optional[str] = None,
         unique: bool = False,
diff --git a/docs/build/unreleased/1603.rst b/docs/build/unreleased/1603.rst
new file mode 100644 (file)
index 0000000..8393ecd
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: usecase, autogenerate
+    :tickets: 1603
+
+    Index autogenerate will now render labels for expressions
+    that use them. This is useful when applying operator classes
+    in PostgreSQL that can be keyed on the label name.
index b5c4e5727d1b6fa01fd21280ff34581fc0483933..f466da3c4ef65fe484cd313304a92700f501e939 100644 (file)
@@ -94,6 +94,32 @@ class AutogenRenderTest(TestBase):
             "['active', 'code'], unique=False)",
         )
 
+    def test_render_add_index_fn(self):
+        t = self.table(Column("other", String(100)))
+        idx = Index("test_fn_idx", t.c.code + t.c.other)
+        op_obj = ops.CreateIndexOp.from_index(idx)
+        eq_ignore_whitespace(
+            autogenerate.render_op_text(self.autogen_context, op_obj),
+            "op.create_index('test_fn_idx', 'test', "
+            "[sa.literal_column('code || other')], unique=False)",
+        )
+
+    def test_render_add_index_label(self):
+        t = self.table(Column("other", String(100)))
+        idx = Index(
+            "test_fn_idx",
+            (t.c.code + t.c.other).label("foo"),
+            t.c.id.label("bar"),
+        )
+        op_obj = ops.CreateIndexOp.from_index(idx)
+        eq_ignore_whitespace(
+            autogenerate.render_op_text(self.autogen_context, op_obj),
+            "op.create_index('test_fn_idx', 'test', ["
+            "sa.literal_column('code || other').label('foo'), "
+            "sa.literal_column('id').label('bar')"
+            "], unique=False)",
+        )
+
     def test_render_add_index_if_not_exists(self):
         """
         autogenerate.render._add_index
@@ -170,7 +196,7 @@ class AutogenRenderTest(TestBase):
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_index('test_active_code_idx', 'test', "
-            "['active', sa.text('lower(code)')], unique=False)",
+            "['active', sa.literal_column('lower(code)')], unique=False)",
         )
         op_obj_rev = op_obj.reverse()
         eq_ignore_whitespace(
@@ -186,7 +212,7 @@ class AutogenRenderTest(TestBase):
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_index('test_lower_code_idx', 'test', "
-            "[sa.text('lower(code)')], unique=False)",
+            "[sa.literal_column('lower(code)')], unique=False)",
         )
         op_obj_rev = op_obj.reverse()
         eq_ignore_whitespace(
@@ -202,7 +228,7 @@ class AutogenRenderTest(TestBase):
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_index('test_lower_code_idx', 'test', "
-            "[sa.text('CAST(code AS VARCHAR)')], unique=False)",
+            "[sa.literal_column('CAST(code AS VARCHAR)')], unique=False)",
         )
 
     def test_render_add_index_desc(self):
@@ -212,7 +238,7 @@ class AutogenRenderTest(TestBase):
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj),
             "op.create_index('test_desc_code_idx', 'test', "
-            "[sa.text('code DESC')], unique=False)",
+            "[sa.literal_column('code DESC')], unique=False)",
         )
 
     def test_drop_index(self):
@@ -256,7 +282,7 @@ class AutogenRenderTest(TestBase):
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj_rev),
             "op.create_index('test_active_code_idx', 'test', "
-            "['active', sa.text('lower(code)')], unique=False)",
+            "['active', sa.literal_column('lower(code)')], unique=False)",
         )
 
     def test_drop_index_func(self):
@@ -274,7 +300,7 @@ class AutogenRenderTest(TestBase):
         eq_ignore_whitespace(
             autogenerate.render_op_text(self.autogen_context, op_obj_rev),
             "op.create_index('test_lower_code_idx', 'test', "
-            "[sa.text('lower(code)')], unique=False)",
+            "[sa.literal_column('lower(code)')], unique=False)",
         )
 
     @testing.emits_warning("Can't validate argument ")
index e42ea9d3b2beb46d352596e4d8534a1d38958d3b..9eec5f2641cd064d1926bd1ee10cd77a3cc28893 100644 (file)
@@ -1330,7 +1330,7 @@ class PostgresqlAutogenRenderTest(TestBase):
                 ops.CreateIndexOp.from_index(idx),
             ),
             "op.create_index('my_idx', 'tbl', "
-            "[sa.text(\"(c ->> 'foo')\")], unique=False)",
+            "[sa.literal_column(\"(c ->> 'foo')\")], unique=False)",
         )
 
     @config.requirements.nulls_not_distinct_sa