]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Similarly, for relationship(), foreign_keys,
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 21 Aug 2010 23:38:28 +0000 (19:38 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 21 Aug 2010 23:38:28 +0000 (19:38 -0400)
remote_side, order_by - all column-based
expressions are enforced - lists of strings
are explicitly disallowed since this is a
very common error

CHANGES
lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/sql/expression.py
test/orm/test_collection.py
test/orm/test_relationships.py

diff --git a/CHANGES b/CHANGES
index aae1139a14ba16e9d7b390f7721b53c14b696a5c..e099f68ae093997b33e826e28b3fc5dda20dd3f7 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -87,6 +87,12 @@ CHANGES
     misleads with incorrect information about
     text() or literal().  [ticket:1863]
 
+  - Similarly, for relationship(), foreign_keys,
+    remote_side, order_by - all column-based 
+    expressions are enforced - lists of strings
+    are explicitly disallowed since this is a
+    very common error
+    
   - Dynamic attributes don't support collection 
     population - added an assertion for when
     set_committed_value() is called, as well as
index b5c4353b3b3f75f1c8d72055f85e508ec80a7d02..a9ad342390c3c4d6cdb3fd31c6a4dbd3ceb03ef5 100644 (file)
@@ -129,7 +129,8 @@ def column_mapped_collection(mapping_spec):
     from sqlalchemy.orm.util import _state_mapper
     from sqlalchemy.orm.attributes import instance_state
 
-    cols = [expression._only_column_elements(q) for q in util.to_list(mapping_spec)]
+    cols = [expression._only_column_elements(q, "mapping_spec") 
+                for q in util.to_list(mapping_spec)]
     if len(cols) == 1:
         def keyfunc(value):
             state = instance_state(value)
index 5788c30f948a0b142154205d5276c4ada84f4f23..7e19d7b16d9fbf89389e9efb6eb7622da7cefa90 100644 (file)
@@ -926,16 +926,17 @@ class RelationshipProperty(StrategizedProperty):
         for attr in 'primaryjoin', 'secondaryjoin':
             val = getattr(self, attr)
             if val is not None:
-                util.assert_arg_type(val, sql.ColumnElement, attr)
-                setattr(self, attr, _orm_deannotate(val))
+                setattr(self, attr, _orm_deannotate(
+                    expression._only_column_elements(val, attr))
+                )
         if self.order_by is not False and self.order_by is not None:
-            self.order_by = [expression._literal_as_column(x) for x in
+            self.order_by = [expression._only_column_elements(x, "order_by") for x in
                              util.to_list(self.order_by)]
         self._user_defined_foreign_keys = \
-            util.column_set(expression._literal_as_column(x) for x in
+            util.column_set(expression._only_column_elements(x, "foreign_keys") for x in
                             util.to_column_set(self._user_defined_foreign_keys))
         self.remote_side = \
-            util.column_set(expression._literal_as_column(x) for x in
+            util.column_set(expression._only_column_elements(x, "remote_side") for x in
                             util.to_column_set(self.remote_side))
         if not self.parent.concrete:
             for inheriting in self.parent.iterate_to_root():
index a7f5d396a213480ec7085bef96477e438cccf86c..6f593ab484131fbb756731fe9f0ad71cc52d3bec 100644 (file)
@@ -1037,12 +1037,12 @@ def _no_literals(element):
     else:
         return element
 
-def _only_column_elements(element):
+def _only_column_elements(element, name):
     if hasattr(element, '__clause_element__'):
         element = element.__clause_element__()
     if not isinstance(element, ColumnElement):
-        raise exc.ArgumentError("Column-based expression object expected; "
-                                "got: %r" % element)
+        raise exc.ArgumentError("Column-based expression object expected for argument '%s'; "
+                                "got: '%s', type %s" % (name, element, type(element)))
     return element
     
 def _corresponding_column_or_error(fromclause, column,
index 405829f7434ff2dc23492eefba5e0be88bac9ba5..a33d2d6d13b1260d007a1a0e372d74d53f35eba0 100644 (file)
@@ -1564,21 +1564,18 @@ class DictHelpersTest(_base.MappedTest):
 
     @testing.resolve_artifact_names
     def test_column_mapped_assertions(self):
-        assert_raises_message(
-            sa_exc.ArgumentError,
-            "Column-based expression object expected; got: 'a'",
-            collections.column_mapped_collection, "a",
-        )
-        assert_raises_message(
-            sa_exc.ArgumentError,
-            "Column-based expression object expected; got",
-            collections.column_mapped_collection, text("a"),
-        )
-        assert_raises_message(
-            sa_exc.ArgumentError,
-            "Column-based expression object expected; got",
-            collections.column_mapped_collection, text("a"),
-        )
+        assert_raises_message(sa_exc.ArgumentError,
+                              "Column-based expression object expected "
+                              "for argument 'mapping_spec'; got: 'a', "
+                              "type <type 'str'>",
+                              collections.column_mapped_collection, 'a')
+        assert_raises_message(sa_exc.ArgumentError,
+                              "Column-based expression object expected "
+                              "for argument 'mapping_spec'; got: 'a', "
+                              "type <class 'sqlalchemy.sql.expression._"
+                              "TextClause'>",
+                              collections.column_mapped_collection,
+                              text('a'))
         
         
     @testing.resolve_artifact_names
index ce315ff35093cf7ba2aaada6bd06715302cafb35..187c9e5349535357bde3806f8b8723150d5f7c4a 100644 (file)
@@ -970,6 +970,36 @@ class JoinConditionErrorTest(testing.TestBase):
         mapper(C2, t2)
         assert_raises(sa.exc.ArgumentError, compile_mappers)
     
+    def test_invalid_string_args(self):
+        from sqlalchemy.ext.declarative import declarative_base
+        from sqlalchemy import util
+        
+        for argname, arg in [
+            ('remote_side', ['c1.id']),
+            ('remote_side', ['id']),
+            ('foreign_keys', ['c1id']),
+            ('foreign_keys', ['C2.c1id']),
+            ('order_by', ['id']),
+        ]:
+            clear_mappers()
+            kw = {argname:arg}
+            Base = declarative_base()
+            class C1(Base):
+                __tablename__ = 'c1'
+                id = Column('id', Integer, primary_key=True)
+            
+            class C2(Base):
+                __tablename__ = 'c2'
+                id_ = Column('id', Integer, primary_key=True)
+                c1id = Column('c1id', Integer, ForeignKey('c1.id'))
+                c2 = relationship(C1, **kw)
+            
+            assert_raises_message(
+                sa.exc.ArgumentError, 
+                "Column-based expression object expected for argument '%s'; got: '%s', type %r" % (argname, arg[0], type(arg[0])),
+                compile_mappers)
+        
+    
     def test_fk_error_raised(self):
         m = MetaData()
         t1 = Table('t1', m,