from sqlalchemy.dialects.postgresql import ExcludeConstraint
from sqlalchemy.dialects.postgresql import INTEGER
from sqlalchemy.schema import CreateIndex
+from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.elements import TextClause
+from sqlalchemy.sql.elements import UnaryExpression
from sqlalchemy.types import NULLTYPE
from .base import alter_column
from sqlalchemy.dialects.postgresql.json import JSON
from sqlalchemy.dialects.postgresql.json import JSONB
from sqlalchemy.sql.elements import BinaryExpression
+ from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.schema import Table
if not sqla_compat.sqla_2:
self._skip_functional_indexes(metadata_indexes, conn_indexes)
- def _cleanup_index_expr(self, index: Index, expr: str) -> str:
+ def _cleanup_index_expr(
+ self, index: Index, expr: str, remove_suffix: str
+ ) -> str:
# start = expr
expr = expr.lower()
expr = expr.replace('"', "")
if index.table is not None:
+ # should not be needed, since include_table=False is in compile
expr = expr.replace(f"{index.table.name.lower()}.", "")
while expr and expr[0] == "(" and expr[-1] == ")":
# strip :: cast. types can have spaces in them
expr = re.sub(r"(::[\w ]+\w)", "", expr)
+ if remove_suffix and expr.endswith(remove_suffix):
+ expr = expr[: -len(remove_suffix)]
+
# print(f"START: {start} END: {expr}")
return expr
+ def _default_modifiers(self, exp: ClauseElement) -> str:
+ to_remove = ""
+ while isinstance(exp, UnaryExpression):
+ if exp.modifier is None:
+ exp = exp.element
+ else:
+ op = exp.modifier
+ if isinstance(exp.element, UnaryExpression):
+ inner_op = exp.element.modifier
+ else:
+ inner_op = None
+ if inner_op is None:
+ if op == operators.asc_op:
+ # default is asc
+ to_remove = " asc"
+ elif op == operators.nullslast_op:
+ # default is nulls last
+ to_remove = " nulls last"
+ else:
+ if (
+ inner_op == operators.asc_op
+ and op == operators.nullslast_op
+ ):
+ # default is asc nulls last
+ to_remove = " asc nulls last"
+ elif (
+ inner_op == operators.desc_op
+ and op == operators.nullsfirst_op
+ ):
+ # default for desc is nulls first
+ to_remove = " nulls first"
+ break
+ return to_remove
+
def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
- if sqla_compat.is_expression_index(index):
- return tuple(
- self._cleanup_index_expr(
- index,
- e
+ return tuple(
+ self._cleanup_index_expr(
+ index,
+ *(
+ (e, "")
if isinstance(e, str)
- else e.compile(
- dialect=self.dialect,
- compile_kwargs={"literal_binds": True},
- ).string,
- )
- for e in index.expressions
+ else (self._compile_element(e), self._default_modifiers(e))
+ ),
)
- else:
- return super().create_index_sig(index)
+ for e in index.expressions
+ )
+
+ def _compile_element(self, element: ClauseElement) -> str:
+ return element.compile(
+ dialect=self.dialect,
+ compile_kwargs={"literal_binds": True, "include_table": False},
+ ).string
def render_type(
self, type_: TypeEngine, autogen_context: AutogenContext
from sqlalchemy import Table
from sqlalchemy import UniqueConstraint
from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION
+from sqlalchemy.sql.expression import asc
from sqlalchemy.sql.expression import column
from sqlalchemy.sql.expression import desc
+from sqlalchemy.sql.expression import nullsfirst
+from sqlalchemy.sql.expression import nullslast
from alembic import testing
from alembic.testing import combinations
eq_(diffs, [])
- # fails in the 0.8 series where we have truncation rules,
- # but no control over quoting. passes in 0.7.9 where we don't have
- # truncation rules either. dropping these ancient versions
- # is long overdue.
-
def test_unchanged_case_sensitive_implicit_idx(self):
m1 = MetaData()
m2 = MetaData()
],
)
+ @config.requirements.reflects_indexes_column_sorting
+ @testing.combinations(
+ (desc, asc),
+ (asc, desc),
+ (desc, lambda x: nullslast(desc(x))),
+ (nullslast, nullsfirst),
+ (nullsfirst, nullslast),
+ (lambda x: nullslast(desc(x)), lambda x: nullsfirst(asc(x))),
+ )
+ def test_column_sort_changed(self, old_fn, new_fn):
+ m1 = MetaData()
+ m2 = MetaData()
+
+ old = Index("SomeIndex", old_fn("y"))
+ Table("order_change", m1, Column("y", Integer), old)
+
+ new = Index("SomeIndex", new_fn("y"))
+ Table("order_change", m2, Column("y", Integer), new)
+ diffs = self._fixture(m1, m2)
+ eq_(
+ diffs,
+ [
+ (
+ "remove_index",
+ schemacompare.CompareIndex(old, name_only=True),
+ ),
+ ("add_index", schemacompare.CompareIndex(new, name_only=True)),
+ ],
+ )
+
+ @config.requirements.reflects_indexes_column_sorting
+ @testing.combinations(
+ (asc, asc),
+ (desc, desc),
+ (nullslast, nullslast),
+ (nullsfirst, nullsfirst),
+ (lambda x: x, asc),
+ (lambda x: x, nullslast),
+ (desc, lambda x: nullsfirst(desc(x))),
+ (lambda x: nullslast(asc(x)), lambda x: x),
+ )
+ def test_column_sort_not_changed(self, old_fn, new_fn):
+ m1 = MetaData()
+ m2 = MetaData()
+
+ old = Index("SomeIndex", old_fn("y"))
+ Table("order_change", m1, Column("y", Integer), old)
+
+ new = Index("SomeIndex", new_fn("y"))
+ Table("order_change", m2, Column("y", Integer), new)
+ diffs = self._fixture(m1, m2)
+ eq_(diffs, [])
+
+
+def _lots_of_indexes(flatten: bool = False):
+ diff_pairs = [
+ (
+ lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", func.lower(t.c.x)),
+ ),
+ (
+ lambda CapT: Index("SomeIndex", "y", func.lower(CapT.c.XCol)),
+ lambda CapT: Index("SomeIndex", func.lower(CapT.c.XCol)),
+ ),
+ (
+ lambda t: Index(
+ "SomeIndex", "y", func.lower(column("x")), _table=t
+ ),
+ lambda t: Index("SomeIndex", func.lower(column("x")), _table=t),
+ ),
+ (
+ lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.y),
+ ),
+ (
+ lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.q)),
+ ),
+ (
+ lambda t: Index("SomeIndex", t.c.z, func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
+ ),
+ (
+ lambda t: Index("SomeIndex", func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
+ ),
+ (
+ lambda t: Index("SomeIndex", t.c.y, func.upper(t.c.x)),
+ lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
+ ),
+ (
+ lambda t: Index("SomeIndex", t.c.y, t.c.ff + 1),
+ lambda t: Index("SomeIndex", t.c.y, t.c.ff + 3),
+ ),
+ (
+ lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x + t.c.q)),
+ ),
+ (
+ lambda t: Index("SomeIndex", t.c.y, t.c.z + 3),
+ lambda t: Index("SomeIndex", t.c.y, t.c.z * 3),
+ ),
+ (
+ lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.q + "42"),
+ lambda t: Index("SomeIndex", func.lower(t.c.q), t.c.x + "42"),
+ ),
+ (
+ lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.z + 42),
+ lambda t: Index("SomeIndex", t.c.z + 42, func.lower(t.c.q)),
+ ),
+ (
+ lambda t: Index("SomeIndex", t.c.ff + 42),
+ lambda t: Index("SomeIndex", 42 + t.c.ff),
+ ),
+ ]
+
+ with_sort = [
+ (
+ lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", "y", desc(func.lower(t.c.x))),
+ ),
+ (
+ lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", desc("y"), func.lower(t.c.x)),
+ ),
+ (
+ lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", "y", nullsfirst(func.lower(t.c.x))),
+ ),
+ (
+ lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", nullsfirst("y"), func.lower(t.c.x)),
+ ),
+ (
+ lambda t: Index("SomeIndex", asc(func.lower(t.c.x))),
+ lambda t: Index("SomeIndex", desc(func.lower(t.c.x))),
+ ),
+ (
+ lambda t: Index("SomeIndex", desc(func.lower(t.c.x))),
+ lambda t: Index("SomeIndex", asc(func.lower(t.c.x))),
+ ),
+ (
+ lambda t: Index("SomeIndex", nullslast(asc(func.lower(t.c.x)))),
+ lambda t: Index("SomeIndex", nullslast(desc(func.lower(t.c.x)))),
+ ),
+ (
+ lambda t: Index("SomeIndex", nullslast(desc(func.lower(t.c.x)))),
+ lambda t: Index("SomeIndex", nullsfirst(desc(func.lower(t.c.x)))),
+ ),
+ (
+ lambda t: Index("SomeIndex", nullsfirst(func.lower(t.c.x))),
+ lambda t: Index("SomeIndex", desc(func.lower(t.c.x))),
+ ),
+ ]
+
+ req = config.requirements.reflects_indexes_column_sorting
+
+ if flatten:
+
+ flat = list(itertools.chain.from_iterable(diff_pairs))
+ for f1, f2 in with_sort:
+ flat.extend([(f1, req), (f2, req)])
+ return flat
+ else:
+ return diff_pairs + [(f1, f2, req) for f1, f2 in with_sort]
+
+
+def _lost_of_equal_indexes(_lots_of_indexes):
+ equal_pairs = [
+ (fn, fn) if not isinstance(fn, tuple) else (fn[0], fn[0], fn[1])
+ for fn in _lots_of_indexes(flatten=True)
+ ]
+ equal_pairs += [
+ (
+ lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", "y", asc(func.lower(t.c.x))),
+ config.requirements.reflects_indexes_column_sorting,
+ ),
+ (
+ lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", "y", nullslast(func.lower(t.c.x))),
+ config.requirements.reflects_indexes_column_sorting,
+ ),
+ (
+ lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+ lambda t: Index(
+ "SomeIndex", "y", nullslast(asc(func.lower(t.c.x)))
+ ),
+ config.requirements.reflects_indexes_column_sorting,
+ ),
+ (
+ lambda t: Index("SomeIndex", "y", desc(func.lower(t.c.x))),
+ lambda t: Index(
+ "SomeIndex", "y", nullsfirst(desc(func.lower(t.c.x)))
+ ),
+ config.requirements.reflects_indexes_column_sorting,
+ ),
+ ]
+ return equal_pairs
+
class AutogenerateExpressionIndexTest(AutogenFixtureTest, TestBase):
"""tests involving indexes with expression"""
diffs = self._fixture(m1, m2)
eq_(diffs, [])
- def _lots_of_indexes(flatten: bool = False):
- diff_pairs = [
- (
- lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
- lambda t: Index("SomeIndex", func.lower(t.c.x)),
- ),
- (
- lambda CapT: Index("SomeIndex", "y", func.lower(CapT.c.XCol)),
- lambda CapT: Index("SomeIndex", func.lower(CapT.c.XCol)),
- ),
- (
- lambda t: Index(
- "SomeIndex", "y", func.lower(column("x")), _table=t
- ),
- lambda t: Index(
- "SomeIndex", func.lower(column("x")), _table=t
- ),
- ),
- (
- lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
- lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.y),
- ),
- (
- lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
- lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.q)),
- ),
- (
- lambda t: Index("SomeIndex", t.c.z, func.lower(t.c.x)),
- lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
- ),
- (
- lambda t: Index("SomeIndex", func.lower(t.c.x)),
- lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
- ),
- (
- lambda t: Index("SomeIndex", t.c.y, func.upper(t.c.x)),
- lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
- ),
- (
- lambda t: Index("SomeIndex", t.c.y, t.c.ff + 1),
- lambda t: Index("SomeIndex", t.c.y, t.c.ff + 3),
- ),
- (
- lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
- lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x + t.c.q)),
- ),
- (
- lambda t: Index("SomeIndex", t.c.y, t.c.z + 3),
- lambda t: Index("SomeIndex", t.c.y, t.c.z * 3),
- ),
- (
- lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.q + "42"),
- lambda t: Index("SomeIndex", func.lower(t.c.q), t.c.x + "42"),
- ),
- (
- lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.z + 42),
- lambda t: Index("SomeIndex", t.c.z + 42, func.lower(t.c.q)),
- ),
- (
- lambda t: Index("SomeIndex", t.c.ff + 42),
- lambda t: Index("SomeIndex", 42 + t.c.ff),
- ),
- ]
- if flatten:
- return list(itertools.chain.from_iterable(diff_pairs))
- else:
- return diff_pairs
-
@testing.fixture
def index_changed_tables(self):
m1 = MetaData()
diffs = self._fixture(m1, m2)
eq_(diffs, [])
- @combinations(*_lots_of_indexes(flatten=True), argnames="fn")
- def test_expression_indexes_no_change(self, index_changed_tables, fn):
+ @combinations(
+ *_lost_of_equal_indexes(_lots_of_indexes), argnames="fn1, fn2"
+ )
+ def test_expression_indexes_no_change(
+ self, index_changed_tables, fn1, fn2
+ ):
m1, m2, old_fixture_tables, new_fixture_tables = index_changed_tables
- resolve_lambda(fn, **old_fixture_tables)
- resolve_lambda(fn, **new_fixture_tables)
+ resolve_lambda(fn1, **old_fixture_tables)
+ resolve_lambda(fn2, **new_fixture_tables)
if self.has_reflection:
ctx = nullcontext()