From: Mike Bayer Date: Thu, 27 Feb 2014 18:50:47 +0000 (-0500) Subject: restore the contracts of update/extend to the degree that the same column identity X-Git-Tag: rel_0_9_4~99 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c2f86c92b1fbb4e855161bd509d3057f86ed7a74;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git restore the contracts of update/extend to the degree that the same column identity isn't appended to the list. reflection makes use of this. --- diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index c2bdd8b1c3..379f61ed79 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -522,17 +522,19 @@ class ColumnCollection(util.OrderedProperties): raise NotImplementedError() def remove(self, column): - raise NotImplementedError() + del self._data[column.key] + self._all_col_set.remove(column) + self._all_columns[:] = [c for c in self._all_columns if c is not column] def update(self, iter): cols = list(iter) - self._all_columns.extend(c for label, c in cols) + self._all_columns.extend(c for label, c in cols if c not in self._all_col_set) self._all_col_set.update(c for label, c in cols) self._data.update((label, c) for label, c in cols) def extend(self, iter): cols = list(iter) - self._all_columns.extend(cols) + self._all_columns.extend(c for c in cols if c not in self._all_col_set) self._all_col_set.update(cols) self._data.update((c.key, c) for c in cols) diff --git a/test/base/test_utils.py b/test/base/test_utils.py index e6ea062969..4ff17e8cc8 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -454,6 +454,36 @@ class ColumnCollectionTest(fixtures.TestBase): # the same as the "_all_columns" collection. eq_(list(cc), [c1, c2b, c3]) + def test_extend_existing(self): + cc = sql.ColumnCollection() + + c1, c2, c3, c4, c5 = column('c1'), column('c2'), column('c3'), column('c4'), column('c5') + + cc.extend([c1, c2]) + eq_(cc._all_columns, [c1, c2]) + + cc.extend([c3]) + eq_(cc._all_columns, [c1, c2, c3]) + cc.extend([c4, c2, c5]) + + eq_(cc._all_columns, [c1, c2, c3, c4, c5]) + + def test_update_existing(self): + cc = sql.ColumnCollection() + + c1, c2, c3, c4, c5 = column('c1'), column('c2'), column('c3'), column('c4'), column('c5') + + cc.update([('c1', c1), ('c2', c2)]) + eq_(cc._all_columns, [c1, c2]) + + cc.update([('c3', c3)]) + eq_(cc._all_columns, [c1, c2, c3]) + cc.update([('c4', c4), ('c2', c2), ('c5', c5)]) + + eq_(cc._all_columns, [c1, c2, c3, c4, c5]) + + + class LRUTest(fixtures.TestBase): def test_lru(self):