]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- ColumnCollection (i.e. the 'c' attribute on tables) follows dictionary
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 24 Jul 2007 15:43:32 +0000 (15:43 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 24 Jul 2007 15:43:32 +0000 (15:43 +0000)
semantics for "__contains__" [ticket:606]

CHANGES
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
test/base/utils.py

diff --git a/CHANGES b/CHANGES
index 668c94626c15aec4a2a19284d9f17e92944eeb60..f47e7be66f7a7866d439784c6a461bafac749f48 100644 (file)
--- a/CHANGES
+++ b/CHANGES
 - engines
   - Connections gain a .properties collection, with contents scoped to the
     lifetime of the underlying DBAPI connection
+  - ColumnCollection (i.e. the 'c' attribute on tables) follows dictionary
+    semantics for "__contains__" [ticket:606]
 - extensions
   - proxyengine is temporarily removed, pending an actually working
     replacement.
index a335cdd69c77d9523415d808db23dda22b348bfa..39853236049df15bf4293a54ab5cd143b2e91a38 100644 (file)
@@ -523,7 +523,7 @@ class PropertyLoader(StrategizedProperty):
             # load "polymorphic" versions of the columns present in "remote_side" - this is
             # important for lazy-clause generation which goes off the polymorphic target selectable
             for c in list(self.remote_side):
-                if self.secondary and c in self.secondary.columns:
+                if self.secondary and self.secondary.columns.contains_column(c):
                     continue
                 for equiv in [c] + (c in target_equivalents and list(target_equivalents[c]) or []): 
                     corr = self.mapper.select_table.corresponding_column(equiv, raiseerr=False)
index 00b9cff68c09d9cb48d1958d65781a841c891776..5b392bdb82c8cfc617ef8de3de29233a20a88038 100644 (file)
@@ -871,7 +871,7 @@ class Constraint(SchemaItem):
         self.columns = sql.ColumnCollection()
 
     def __contains__(self, x):
-        return x in self.columns
+        return self.columns.contains_column(x)
 
     def keys(self):
         return self.columns.keys()
index 49fbb3aa03f51c2c06ccf734eb003c8c2044aec5..c463e1e99518954c8047a3dcfa115d186440a4f7 100644 (file)
@@ -1593,8 +1593,10 @@ class ColumnCollection(util.OrderedProperties):
                     l.append(c==local)
         return and_(*l)
 
-    def __contains__(self, col):
-        return self.contains_column(col)
+    def __contains__(self, other):
+        if not isinstance(other, basestring):
+            raise exceptions.ArgumentError("__contains__ requires a string argument")
+        return self.has_key(other)
         
     def contains_column(self, col):
         # have to use a Set here, because it will compare the identity
@@ -1714,7 +1716,7 @@ class FromClause(Selectable):
           the exported columns of this ``FromClause``.
         """
             
-        if column in self.c:
+        if self.c.contains_column(column):
             return column
 
         if require_embedded and column not in util.Set(self._get_all_embedded_columns()):
index 96d3c96e432723a595102c34f92c7e35e46c81b1..97f3db06fcd871b50373239ad8a2f4d0c3fda406 100644 (file)
@@ -1,5 +1,5 @@
 import testbase
-from sqlalchemy import util
+from sqlalchemy import util, column, sql, exceptions
 from testlib import *
 
 
@@ -34,5 +34,34 @@ class OrderedDictTest(PersistTest):
         self.assert_(o.keys() == ['a', 'b', 'c', 'd', 'e', 'f'])
         self.assert_(o.values() == [1, 2, 3, 4, 5, 6])
 
+class ColumnCollectionTest(PersistTest):
+    def test_in(self):
+        cc = sql.ColumnCollection()
+        cc.add(column('col1'))
+        cc.add(column('col2'))
+        cc.add(column('col3'))
+        assert 'col1' in cc
+        assert 'col2' in cc
+
+        try:
+            cc['col1'] in cc
+            assert False
+        except exceptions.ArgumentError, e:
+            assert str(e) == "__contains__ requires a string argument"
+            
+    def test_compare(self):
+        cc1 = sql.ColumnCollection()
+        cc2 = sql.ColumnCollection()
+        cc3 = sql.ColumnCollection()
+        c1 = column('col1')
+        c2 = c1.label('col2')
+        c3 = column('col3')
+        cc1.add(c1)
+        cc2.add(c2)
+        cc3.add(c3)
+        assert (cc1==cc2).compare(c1 == c2)
+        assert not (cc1==cc3).compare(c2 == c3)
+        
+        
 if __name__ == "__main__":
     testbase.main()