]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed bug preventing declarative-bound "column" objects
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 7 Nov 2008 18:20:53 +0000 (18:20 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 7 Nov 2008 18:20:53 +0000 (18:20 +0000)
from being used in column_mapped_collection().  [ticket:1174]

CHANGES
lib/sqlalchemy/orm/collections.py
test/orm/collection.py

diff --git a/CHANGES b/CHANGES
index 51d0e1aadd28d2a16f9b8a65f9136657cf8d02e3..d09c8276e79a52af79c7d19743e83aebb2f1186b 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -146,6 +146,10 @@ CHANGES
     - No longer expects include_columns in table reflection to be
       lower case.
 
+- ext
+    - Fixed bug preventing declarative-bound "column" objects 
+      from being used in column_mapped_collection().  [ticket:1174]
+
 - misc
     - util.flatten_iterator() func doesn't interpret strings with
       __iter__() methods as iterators, such as in pypy [ticket:1077].
index 497ef5941162d31b74088d8f7e4f293b4083e3ef..2105a4fe6a3c672e1b92f5b4e8f4aa5ff413e049 100644 (file)
@@ -104,6 +104,7 @@ import sys
 import weakref
 
 import sqlalchemy.exceptions as sa_exc
+from sqlalchemy.sql import expression
 from sqlalchemy import schema
 import sqlalchemy.util as sautil
 
@@ -130,18 +131,13 @@ def column_mapped_collection(mapping_spec):
     from sqlalchemy.orm.util import _state_mapper
     from sqlalchemy.orm.attributes import instance_state
 
-    if isinstance(mapping_spec, schema.Column):
+    cols = [expression._no_literals(q) for q in sautil.to_list(mapping_spec)]
+    if len(cols) == 1:
         def keyfunc(value):
             state = instance_state(value)
             m = _state_mapper(state)
-            return m._get_state_attr_by_column(state, mapping_spec)
+            return m._get_state_attr_by_column(state, cols[0])
     else:
-        cols = []
-        for c in mapping_spec:
-            if not isinstance(c, schema.Column):
-                raise sa_exc.ArgumentError(
-                    "mapping_spec tuple may only contain columns")
-            cols.append(c)
         mapping_spec = tuple(cols)
         def keyfunc(value):
             state = instance_state(value)
index 0d858487334113e7313220c36a6e33d4a52ccdfc..c37a20b681be8b6d0d4b7ff808edc2e2a6d03b36 100644 (file)
@@ -11,7 +11,7 @@ from testlib.sa import util, exc as sa_exc
 from testlib.sa.orm import create_session, mapper, relation, \
     attributes
 from orm import _base
-
+from testlib.testing import eq_
 
 class Canary(sa.orm.interfaces.AttributeExtension):
     def __init__(self):
@@ -1369,7 +1369,8 @@ class DictHelpersTest(_base.MappedTest):
 
         p = session.query(Parent).get(pid)
 
-        self.assert_(set(p.children.keys()) == set(['foo', 'bar']))
+        
+        self.assertEquals(set(p.children.keys()), set(['foo', 'bar']))
         cid = p.children['foo'].id
 
         collections.collection_adapter(p.children).append_with_event(
@@ -1457,6 +1458,27 @@ class DictHelpersTest(_base.MappedTest):
         collection_class = collections.attribute_mapped_collection('a')
         self._test_scalar_mapped(collection_class)
 
+    def test_declarative_column_mapped(self):
+        """test that uncompiled attribute usage works with column_mapped_collection"""
+        
+        from sqlalchemy.ext.declarative import declarative_base
+
+        BaseObject = declarative_base()
+
+        class Foo(BaseObject):
+            __tablename__ = "foo"
+            id = Column(Integer(), primary_key=True)
+            bar_id = Column(Integer, ForeignKey('bar.id'))
+            
+        class Bar(BaseObject):
+            __tablename__ = "bar"
+            id = Column(Integer(), primary_key=True)
+            foos = relation(Foo, collection_class=collections.column_mapped_collection(Foo.id))
+            foos2 = relation(Foo, collection_class=collections.column_mapped_collection((Foo.id, Foo.bar_id)))
+            
+        eq_(Bar.foos.property.collection_class().keyfunc(Foo(id=3)), 3)
+        eq_(Bar.foos2.property.collection_class().keyfunc(Foo(id=3, bar_id=12)), (3, 12))
+        
     @testing.resolve_artifact_names
     def test_column_mapped_collection(self):
         collection_class = collections.column_mapped_collection(