]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fully copy index expressions
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 14 Dec 2017 18:58:58 +0000 (13:58 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 14 Dec 2017 18:58:58 +0000 (13:58 -0500)
Fixed bug where the :meth:`.Table.tometadata` method would not properly
accommodate :class:`.Index` objects that didn't consist of simple
column expressions, such as indexes against a :func:`.text` construct,
indexes that used SQL expressions or :attr:`.func`, etc.   The routine
now copies expressions fully to a new :class:`.Index` object while
substituting all table-bound :class:`.Column` objects for those
of the target table.

Also refined the means by which tometadata() checks if an Index
or UniqueConstraint is generated by a column-level flag, by propagating
an attribute "_column_flag=True" to such indexes/constraints.

Change-Id: I7ef1b8ea42f9933357ae35f241a5ba9838bac35b
Fixes: #4147
doc/build/changelog/unreleased_12/4147.rst [new file with mode: 0644]
lib/sqlalchemy/sql/schema.py
test/sql/test_metadata.py

diff --git a/doc/build/changelog/unreleased_12/4147.rst b/doc/build/changelog/unreleased_12/4147.rst
new file mode 100644 (file)
index 0000000..5369e4b
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 4147
+
+    Fixed bug where the :meth:`.Table.tometadata` method would not properly
+    accommodate :class:`.Index` objects that didn't consist of simple
+    column expressions, such as indexes against a :func:`.text` construct,
+    indexes that used SQL expressions or :attr:`.func`, etc.   The routine
+    now copies expressions fully to a new :class:`.Index` object while
+    substituting all table-bound :class:`.Column` objects for those
+    of the target table.
\ No newline at end of file
index 683da823f263137c126d830e5d7640bee25a063b..dcca01e56e26243d7e7e4acc658b0deceb801edd 100644 (file)
@@ -65,6 +65,17 @@ def _get_table_key(name, schema):
         return schema + "." + name
 
 
+# this should really be in sql/util.py but we'd have to
+# break an import cycle
+def _copy_expression(expression, source_table, target_table):
+    def replace(col):
+        if source_table.c.contains_column(col):
+            return target_table.c[col.key]
+        else:
+            return None
+    return visitors.replacement_traverse(expression, {}, replace)
+
+
 @inspection._self_inspects
 class SchemaItem(SchemaEventTarget, visitors.Visitable):
     """Base class for items that define a database schema."""
@@ -882,9 +893,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
             elif not c._type_bound:
                 # skip unique constraints that would be generated
                 # by the 'unique' flag on Column
-                if isinstance(c, UniqueConstraint) and \
-                    len(c.columns) == 1 and \
-                        list(c.columns)[0].unique:
+                if c._column_flag:
                     continue
 
                 table.append_constraint(
@@ -892,12 +901,13 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
         for index in self.indexes:
             # skip indexes that would be generated
             # by the 'index' flag on Column
-            if len(index.columns) == 1 and \
-                    list(index.columns)[0].index:
+            if index._column_flag:
                 continue
             Index(index.name,
                   unique=index.unique,
-                  *[table.c[col] for col in index.columns.keys()],
+                  *[_copy_expression(expr, self, table)
+                    for expr in index.expressions],
+                  _table=table,
                   **index.kwargs)
         return self._schema_item_copy(table)
 
@@ -1372,7 +1382,7 @@ class Column(SchemaItem, ColumnClause):
                     "The 'index' keyword argument on Column is boolean only. "
                     "To create indexes with a specific name, create an "
                     "explicit Index object external to the Table.")
-            Index(None, self, unique=bool(self.unique))
+            Index(None, self, unique=bool(self.unique), _column_flag=True)
         elif self.unique:
             if isinstance(self.unique, util.string_types):
                 raise exc.ArgumentError(
@@ -1381,7 +1391,8 @@ class Column(SchemaItem, ColumnClause):
                     "specific name, append an explicit UniqueConstraint to "
                     "the Table's list of elements, or create an explicit "
                     "Index object external to the Table.")
-            table.append_constraint(UniqueConstraint(self.key))
+            table.append_constraint(
+                UniqueConstraint(self.key, _column_flag=True))
 
         self._setup_on_memoized_fks(lambda fk: fk._set_remote_table(table))
 
@@ -2576,6 +2587,7 @@ class ColumnCollectionMixin(object):
 
     def __init__(self, *columns, **kw):
         _autoattach = kw.pop('_autoattach', True)
+        self._column_flag = kw.pop('_column_flag', False)
         self.columns = ColumnCollection()
         self._pending_colargs = [_to_schema_column_or_string(c)
                                  for c in columns]
@@ -2681,8 +2693,10 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
 
         """
         _autoattach = kw.pop('_autoattach', True)
+        _column_flag = kw.pop('_column_flag', False)
         Constraint.__init__(self, **kw)
-        ColumnCollectionMixin.__init__(self, *columns, _autoattach=_autoattach)
+        ColumnCollectionMixin.__init__(
+            self, *columns, _autoattach=_autoattach, _column_flag=_column_flag)
 
     columns = None
     """A :class:`.ColumnCollection` representing the set of columns
@@ -2785,12 +2799,8 @@ class CheckConstraint(ColumnCollectionConstraint):
 
     def copy(self, target_table=None, **kw):
         if target_table is not None:
-            def replace(col):
-                if self.table.c.contains_column(col):
-                    return target_table.c[col.key]
-                else:
-                    return None
-            sqltext = visitors.replacement_traverse(self.sqltext, {}, replace)
+            sqltext = _copy_expression(
+                self.sqltext, self.table, target_table)
         else:
             sqltext = self.sqltext
         c = CheckConstraint(sqltext,
@@ -3417,6 +3427,7 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
         self.expressions = processed_expressions
         self.name = quoted_name(name, kw.pop("quote", None))
         self.unique = kw.pop('unique', False)
+        _column_flag = kw.pop('_column_flag', False)
         if 'info' in kw:
             self.info = kw.pop('info')
 
@@ -3429,7 +3440,8 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
 
         # will call _set_parent() if table-bound column
         # objects are present
-        ColumnCollectionMixin.__init__(self, *columns)
+        ColumnCollectionMixin.__init__(
+            self, *columns, _column_flag=_column_flag)
 
         if table is not None:
             self._set_parent(table)
index 45eb594534f00ed36f178863a85c6e74ad1792ec..7d17e25f34ac0177f7806ea937c4560c3ac96147 100644 (file)
@@ -1047,8 +1047,11 @@ class ToMetaDataTest(fixtures.TestBase, ComparesTables):
                       Column('id', Integer, primary_key=True),
                       Column('data1', Integer, index=True),
                       Column('data2', Integer),
+                      Index('text', text('data1 + 1')),
                       )
-        Index('multi', table.c.data1, table.c.data2),
+        Index('multi', table.c.data1, table.c.data2)
+        Index('func', func.abs(table.c.data1))
+        Index('multi-func', table.c.data1, func.abs(table.c.data2))
 
         meta2 = MetaData()
         table_c = table.tometadata(meta2)
@@ -1056,7 +1059,7 @@ class ToMetaDataTest(fixtures.TestBase, ComparesTables):
         def _get_key(i):
             return [i.name, i.unique] + \
                 sorted(i.kwargs.items()) + \
-                list(i.columns.keys())
+                [str(col) for col in i.expressions]
 
         eq_(
             sorted([_get_key(i) for i in table.indexes]),