]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
restore the contracts of update/extend to the degree that the same column identity
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 27 Feb 2014 18:50:47 +0000 (13:50 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 27 Feb 2014 18:50:47 +0000 (13:50 -0500)
isn't appended to the list.  reflection makes use of this.

lib/sqlalchemy/sql/base.py
test/base/test_utils.py

index c2bdd8b1c3702d1316c1c78960a5831edf4c3bfe..379f61ed79f0c26a6b16f49ba782e41a6efeaef3 100644 (file)
@@ -522,17 +522,19 @@ class ColumnCollection(util.OrderedProperties):
         raise NotImplementedError()
 
     def remove(self, column):
-        raise NotImplementedError()
+        del self._data[column.key]
+        self._all_col_set.remove(column)
+        self._all_columns[:] = [c for c in self._all_columns if c is not column]
 
     def update(self, iter):
         cols = list(iter)
-        self._all_columns.extend(c for label, c in cols)
+        self._all_columns.extend(c for label, c in cols if c not in self._all_col_set)
         self._all_col_set.update(c for label, c in cols)
         self._data.update((label, c) for label, c in cols)
 
     def extend(self, iter):
         cols = list(iter)
-        self._all_columns.extend(cols)
+        self._all_columns.extend(c for c in cols if c not in self._all_col_set)
         self._all_col_set.update(cols)
         self._data.update((c.key, c) for c in cols)
 
index e6ea062969340db0f9533affa1465de4b414ff41..4ff17e8cc8f6c3c9bba661b3f2a75f083f71c237 100644 (file)
@@ -454,6 +454,36 @@ class ColumnCollectionTest(fixtures.TestBase):
         # the same as the "_all_columns" collection.
         eq_(list(cc), [c1, c2b, c3])
 
+    def test_extend_existing(self):
+        cc = sql.ColumnCollection()
+
+        c1, c2, c3, c4, c5 = column('c1'), column('c2'), column('c3'), column('c4'), column('c5')
+
+        cc.extend([c1, c2])
+        eq_(cc._all_columns, [c1, c2])
+
+        cc.extend([c3])
+        eq_(cc._all_columns, [c1, c2, c3])
+        cc.extend([c4, c2, c5])
+
+        eq_(cc._all_columns, [c1, c2, c3, c4, c5])
+
+    def test_update_existing(self):
+        cc = sql.ColumnCollection()
+
+        c1, c2, c3, c4, c5 = column('c1'), column('c2'), column('c3'), column('c4'), column('c5')
+
+        cc.update([('c1', c1), ('c2', c2)])
+        eq_(cc._all_columns, [c1, c2])
+
+        cc.update([('c3', c3)])
+        eq_(cc._all_columns, [c1, c2, c3])
+        cc.update([('c4', c4), ('c2', c2), ('c5', c5)])
+
+        eq_(cc._all_columns, [c1, c2, c3, c4, c5])
+
+
+
 class LRUTest(fixtures.TestBase):
 
     def test_lru(self):