]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Some changes to how the :attr:`.FromClause.c` collection behaves
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 26 Feb 2014 20:34:49 +0000 (15:34 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 26 Feb 2014 20:34:49 +0000 (15:34 -0500)
when presented with duplicate columns.  The behavior of emitting a
warning and replacing the old column with the same name still
remains to some degree; the replacement in particular is to maintain
backwards compatibility.  However, the replaced column still remains
associated with the ``c`` collection now in a collection ``._all_columns``,
which is used by constructs such as aliases and unions, to deal with
the set of columns in ``c`` more towards what is actually in the
list of columns rather than the unique set of key names.  This helps
with situations where SELECT statements with same-named columns
are used in unions and such, so that the union can match the columns
up positionally and also there's some chance of :meth:`.FromClause.corresponding_column`
still being usable here (it can now return a column that is only
in selectable.c._all_columns and not otherwise named).
The new collection is underscored as we still need to decide where this
list might end up.   Theoretically it
would become the result of iter(selectable.c), however this would mean
that the length of the iteration would no longer match the length of
keys(), and that behavior needs to be checked out.
fixes #2974
- add a bunch more tests for ColumnCollection

doc/build/changelog/changelog_09.rst
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/selectable.py
test/base/test_utils.py
test/sql/test_selectable.py

index a6245bdb774bab6bd9402925472640dc967a779b..9a1ff9f277ee10199fef72f88f4d99a27a32437b 100644 (file)
 .. changelog::
     :version: 0.9.4
 
+    .. change::
+        :tags: bug, sql
+        :tickets: 2974
+
+        Some changes to how the :attr:`.FromClause.c` collection behaves
+        when presented with duplicate columns.  The behavior of emitting a
+        warning and replacing the old column with the same name still
+        remains to some degree; the replacement in particular is to maintain
+        backwards compatibility.  However, the replaced column still remains
+        associated with the ``c`` collection now in a collection ``._all_columns``,
+        which is used by constructs such as aliases and unions, to deal with
+        the set of columns in ``c`` more towards what is actually in the
+        list of columns rather than the unique set of key names.  This helps
+        with situations where SELECT statements with same-named columns
+        are used in unions and such, so that the union can match the columns
+        up positionally and also there's some chance of :meth:`.FromClause.corresponding_column`
+        still being usable here (it can now return a column that is only
+        in selectable.c._all_columns and not otherwise named).
+        The new collection is underscored as we still need to decide where this
+        list might end up.   Theoretically it
+        would become the result of iter(selectable.c), however this would mean
+        that the length of the iteration would no longer match the length of
+        keys(), and that behavior needs to be checked out.
+
     .. change::
         :tags: bug, sql
 
index 260cdab660a180904dde952173b13a4228ec8a84..c2bdd8b1c3702d1316c1c78960a5831edf4c3bfe 100644 (file)
@@ -435,10 +435,10 @@ class ColumnCollection(util.OrderedProperties):
 
     """
 
-    def __init__(self, *cols):
+    def __init__(self):
         super(ColumnCollection, self).__init__()
-        self._data.update((c.key, c) for c in cols)
-        self.__dict__['_all_cols'] = util.column_set(self)
+        self.__dict__['_all_col_set'] = util.column_set()
+        self.__dict__['_all_columns'] = []
 
     def __str__(self):
         return repr([str(c) for c in self])
@@ -459,15 +459,26 @@ class ColumnCollection(util.OrderedProperties):
            Used by schema.Column to override columns during table reflection.
 
         """
+        remove_col = None
         if column.name in self and column.key != column.name:
             other = self[column.name]
             if other.name == other.key:
-                del self._data[other.name]
-                self._all_cols.remove(other)
+                remove_col = other
+                self._all_col_set.remove(other)
+                del self._data[other.key]
+
         if column.key in self._data:
-            self._all_cols.remove(self._data[column.key])
-        self._all_cols.add(column)
+            remove_col = self._data[column.key]
+            self._all_col_set.remove(remove_col)
+
+        self._all_col_set.add(column)
         self._data[column.key] = column
+        if remove_col is not None:
+            self._all_columns[:] = [column if c is remove_col
+                                            else c for c in self._all_columns]
+        else:
+            self._all_columns.append(column)
+
 
     def add(self, column):
         """Add a column to this collection.
@@ -497,37 +508,41 @@ class ColumnCollection(util.OrderedProperties):
                           '%r, which has the same key.  Consider '
                           'use_labels for select() statements.' % (key,
                           getattr(existing, 'table', None), value))
-            self._all_cols.remove(existing)
+
             # pop out memoized proxy_set as this
             # operation may very well be occurring
             # in a _make_proxy operation
             util.memoized_property.reset(value, "proxy_set")
-        self._all_cols.add(value)
+
+        self._all_col_set.add(value)
+        self._all_columns.append(value)
         self._data[key] = value
 
     def clear(self):
-        self._data.clear()
-        self._all_cols.clear()
+        raise NotImplementedError()
 
     def remove(self, column):
-        del self._data[column.key]
-        self._all_cols.remove(column)
+        raise NotImplementedError()
 
-    def update(self, value):
-        self._data.update(value)
-        self._all_cols.clear()
-        self._all_cols.update(self._data.values())
+    def update(self, iter):
+        cols = list(iter)
+        self._all_columns.extend(c for label, c in cols)
+        self._all_col_set.update(c for label, c in cols)
+        self._data.update((label, c) for label, c in cols)
 
     def extend(self, iter):
-        self.update((c.key, c) for c in iter)
+        cols = list(iter)
+        self._all_columns.extend(cols)
+        self._all_col_set.update(cols)
+        self._data.update((c.key, c) for c in cols)
 
     __hash__ = None
 
     @util.dependencies("sqlalchemy.sql.elements")
     def __eq__(self, elements, other):
         l = []
-        for c in other:
-            for local in self:
+        for c in getattr(other, "_all_columns", other):
+            for local in self._all_columns:
                 if c.shares_lineage(local):
                     l.append(c == local)
         return elements.and_(*l)
@@ -537,22 +552,28 @@ class ColumnCollection(util.OrderedProperties):
             raise exc.ArgumentError("__contains__ requires a string argument")
         return util.OrderedProperties.__contains__(self, other)
 
+    def __getstate__(self):
+        return {'_data': self.__dict__['_data'],
+                '_all_columns': self.__dict__['_all_columns']}
+
     def __setstate__(self, state):
         self.__dict__['_data'] = state['_data']
-        self.__dict__['_all_cols'] = util.column_set(self._data.values())
+        self.__dict__['_all_columns'] = state['_all_columns']
+        self.__dict__['_all_col_set'] = util.column_set(state['_all_columns'])
 
     def contains_column(self, col):
         # this has to be done via set() membership
-        return col in self._all_cols
+        return col in self._all_col_set
 
     def as_immutable(self):
-        return ImmutableColumnCollection(self._data, self._all_cols)
+        return ImmutableColumnCollection(self._data, self._all_col_set, self._all_columns)
 
 
 class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection):
-    def __init__(self, data, colset):
+    def __init__(self, data, colset, all_columns):
         util.ImmutableProperties.__init__(self, data)
-        self.__dict__['_all_cols'] = colset
+        self.__dict__['_all_col_set'] = colset
+        self.__dict__['_all_columns'] = all_columns
 
     extend = remove = util.ImmutableProperties._immutable
 
index 59d6687b54625d43115f7d932a65a2bfabf501d7..d59b45fae53623004e4b28d9b51b2336d5415cb8 100644 (file)
@@ -342,7 +342,7 @@ class FromClause(Selectable):
             return column
         col, intersect = None, None
         target_set = column.proxy_set
-        cols = self.c
+        cols = self.c._all_columns
         for c in cols:
             expanded_proxy_set = set(_expand_cloned(c.proxy_set))
             i = target_set.intersection(expanded_proxy_set)
@@ -934,6 +934,7 @@ class Alias(FromClause):
                     or 'anon'))
         self.name = name
 
+
     @property
     def description(self):
         if util.py3k:
@@ -954,7 +955,7 @@ class Alias(FromClause):
         return self.element.is_derived_from(fromclause)
 
     def _populate_column_collection(self):
-        for col in self.element.columns:
+        for col in self.element.columns._all_columns:
             col._make_proxy(self)
 
     def _refresh_for_new_column(self, column):
@@ -1738,13 +1739,13 @@ class CompoundSelect(GenerativeSelect):
             s = _clause_element_as_expr(s)
 
             if not numcols:
-                numcols = len(s.c)
-            elif len(s.c) != numcols:
+                numcols = len(s.c._all_columns)
+            elif len(s.c._all_columns) != numcols:
                 raise exc.ArgumentError('All selectables passed to '
                         'CompoundSelect must have identical numbers of '
                         'columns; select #%d has %d columns, select '
-                        '#%d has %d' % (1, len(self.selects[0].c), n
-                        + 1, len(s.c)))
+                        '#%d has %d' % (1, len(self.selects[0].c._all_columns), n
+                        + 1, len(s.c._all_columns)))
 
             self.selects.append(s.self_group(self))
 
@@ -1876,7 +1877,7 @@ class CompoundSelect(GenerativeSelect):
         return False
 
     def _populate_column_collection(self):
-        for cols in zip(*[s.c for s in self.selects]):
+        for cols in zip(*[s.c._all_columns for s in self.selects]):
 
             # this is a slightly hacky thing - the union exports a
             # column that resembles just that of the *first* selectable.
index 86e4b190a04de9488aafe3eefb3c698ddead2f38..e6ea062969340db0f9533affa1465de4b414ff41 100644 (file)
@@ -5,7 +5,7 @@ from sqlalchemy.testing import assert_raises, assert_raises_message, fixtures
 from sqlalchemy.testing import eq_, is_, ne_, fails_if
 from sqlalchemy.testing.util import picklers, gc_collect
 from sqlalchemy.util import classproperty, WeakSequence, get_callable_argspec
-
+from sqlalchemy.sql import column
 
 class KeyedTupleTest():
 
@@ -298,6 +298,161 @@ class ColumnCollectionTest(fixtures.TestBase):
         assert (cc1 == cc2).compare(c1 == c2)
         assert not (cc1 == cc3).compare(c2 == c3)
 
+    @testing.emits_warning("Column ")
+    def test_dupes_add(self):
+        cc = sql.ColumnCollection()
+
+        c1, c2a, c3, c2b = column('c1'), column('c2'), column('c3'), column('c2')
+
+        cc.add(c1)
+        cc.add(c2a)
+        cc.add(c3)
+        cc.add(c2b)
+
+        eq_(cc._all_columns, [c1, c2a, c3, c2b])
+
+        # for iter, c2a is replaced by c2b, ordering
+        # is maintained in that way.  ideally, iter would be
+        # the same as the "_all_columns" collection.
+        eq_(list(cc), [c1, c2b, c3])
+
+        assert cc.contains_column(c2a)
+        assert cc.contains_column(c2b)
+
+        ci = cc.as_immutable()
+        eq_(ci._all_columns, [c1, c2a, c3, c2b])
+        eq_(list(ci), [c1, c2b, c3])
+
+    def test_replace(self):
+        cc = sql.ColumnCollection()
+
+        c1, c2a, c3, c2b = column('c1'), column('c2'), column('c3'), column('c2')
+
+        cc.add(c1)
+        cc.add(c2a)
+        cc.add(c3)
+
+        cc.replace(c2b)
+
+        eq_(cc._all_columns, [c1, c2b, c3])
+        eq_(list(cc), [c1, c2b, c3])
+
+        assert not cc.contains_column(c2a)
+        assert cc.contains_column(c2b)
+
+        ci = cc.as_immutable()
+        eq_(ci._all_columns, [c1, c2b, c3])
+        eq_(list(ci), [c1, c2b, c3])
+
+    def test_replace_key_matches(self):
+        cc = sql.ColumnCollection()
+
+        c1, c2a, c3, c2b = column('c1'), column('c2'), column('c3'), column('X')
+        c2b.key = 'c2'
+
+        cc.add(c1)
+        cc.add(c2a)
+        cc.add(c3)
+
+        cc.replace(c2b)
+
+        assert not cc.contains_column(c2a)
+        assert cc.contains_column(c2b)
+
+        eq_(cc._all_columns, [c1, c2b, c3])
+        eq_(list(cc), [c1, c2b, c3])
+
+        ci = cc.as_immutable()
+        eq_(ci._all_columns, [c1, c2b, c3])
+        eq_(list(ci), [c1, c2b, c3])
+
+    def test_replace_name_matches(self):
+        cc = sql.ColumnCollection()
+
+        c1, c2a, c3, c2b = column('c1'), column('c2'), column('c3'), column('c2')
+        c2b.key = 'X'
+
+        cc.add(c1)
+        cc.add(c2a)
+        cc.add(c3)
+
+        cc.replace(c2b)
+
+        assert not cc.contains_column(c2a)
+        assert cc.contains_column(c2b)
+
+        eq_(cc._all_columns, [c1, c2b, c3])
+        eq_(list(cc), [c1, c3, c2b])
+
+        ci = cc.as_immutable()
+        eq_(ci._all_columns, [c1, c2b, c3])
+        eq_(list(ci), [c1, c3, c2b])
+
+    def test_replace_no_match(self):
+        cc = sql.ColumnCollection()
+
+        c1, c2, c3, c4 = column('c1'), column('c2'), column('c3'), column('c4')
+        c4.key = 'X'
+
+        cc.add(c1)
+        cc.add(c2)
+        cc.add(c3)
+
+        cc.replace(c4)
+
+        assert cc.contains_column(c2)
+        assert cc.contains_column(c4)
+
+        eq_(cc._all_columns, [c1, c2, c3, c4])
+        eq_(list(cc), [c1, c2, c3, c4])
+
+        ci = cc.as_immutable()
+        eq_(ci._all_columns, [c1, c2, c3, c4])
+        eq_(list(ci), [c1, c2, c3, c4])
+
+    def test_dupes_extend(self):
+        cc = sql.ColumnCollection()
+
+        c1, c2a, c3, c2b = column('c1'), column('c2'), column('c3'), column('c2')
+
+        cc.add(c1)
+        cc.add(c2a)
+
+        cc.extend([c3, c2b])
+
+        eq_(cc._all_columns, [c1, c2a, c3, c2b])
+
+        # for iter, c2a is replaced by c2b, ordering
+        # is maintained in that way.  ideally, iter would be
+        # the same as the "_all_columns" collection.
+        eq_(list(cc), [c1, c2b, c3])
+
+        assert cc.contains_column(c2a)
+        assert cc.contains_column(c2b)
+
+        ci = cc.as_immutable()
+        eq_(ci._all_columns, [c1, c2a, c3, c2b])
+        eq_(list(ci), [c1, c2b, c3])
+
+    def test_dupes_update(self):
+        cc = sql.ColumnCollection()
+
+        c1, c2a, c3, c2b = column('c1'), column('c2'), column('c3'), column('c2')
+
+        cc.add(c1)
+        cc.add(c2a)
+
+        cc.update([(c3.key, c3), (c2b.key, c2b)])
+
+        eq_(cc._all_columns, [c1, c2a, c3, c2b])
+
+        assert cc.contains_column(c2a)
+        assert cc.contains_column(c2b)
+
+        # for iter, c2a is replaced by c2b, ordering
+        # is maintained in that way.  ideally, iter would be
+        # the same as the "_all_columns" collection.
+        eq_(list(cc), [c1, c2b, c3])
 
 class LRUTest(fixtures.TestBase):
 
index 9617cfdf7b5bc120819cd6ccf7ea35cf239c6f88..5d3d53b8885e36b41c74d6c86117a0f87fd11ad2 100644 (file)
@@ -413,6 +413,41 @@ class SelectableTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiled
         assert u2.corresponding_column(s1.c.col1) is u2.c.col1
         assert u2.corresponding_column(s2.c.col1) is u2.c.col1
 
+    @testing.emits_warning("Column 'col1'")
+    def test_union_dupe_keys(self):
+        s1 = select([table1.c.col1, table1.c.col2, table2.c.col1])
+        s2 = select([table2.c.col1, table2.c.col2, table2.c.col3])
+        u1 = union(s1, s2)
+
+        assert u1.corresponding_column(s1.c._all_columns[0]) is u1.c._all_columns[0]
+        assert u1.corresponding_column(s2.c.col1) is u1.c._all_columns[0]
+        assert u1.corresponding_column(s1.c.col2) is u1.c.col2
+        assert u1.corresponding_column(s2.c.col2) is u1.c.col2
+
+        assert u1.corresponding_column(s2.c.col3) is u1.c._all_columns[2]
+
+        assert u1.corresponding_column(table2.c.col1) is u1.c._all_columns[2]
+        assert u1.corresponding_column(table2.c.col3) is u1.c._all_columns[2]
+
+    @testing.emits_warning("Column 'col1'")
+    def test_union_alias_dupe_keys(self):
+        s1 = select([table1.c.col1, table1.c.col2, table2.c.col1]).alias()
+        s2 = select([table2.c.col1, table2.c.col2, table2.c.col3])
+        u1 = union(s1, s2)
+
+        assert u1.corresponding_column(s1.c._all_columns[0]) is u1.c._all_columns[0]
+        assert u1.corresponding_column(s2.c.col1) is u1.c._all_columns[0]
+        assert u1.corresponding_column(s1.c.col2) is u1.c.col2
+        assert u1.corresponding_column(s2.c.col2) is u1.c.col2
+
+        assert u1.corresponding_column(s2.c.col3) is u1.c._all_columns[2]
+
+        # this differs from the non-alias test because table2.c.col1 is
+        # more directly at s2.c.col1 than it is s1.c.col1.
+        assert u1.corresponding_column(table2.c.col1) is u1.c._all_columns[0]
+        assert u1.corresponding_column(table2.c.col3) is u1.c._all_columns[2]
+
+
     def test_select_union(self):
 
         # like testaliasunion, but off a Select off the union.