]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Specifying a non-column based argument
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 2 Aug 2010 20:59:06 +0000 (16:59 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 2 Aug 2010 20:59:06 +0000 (16:59 -0400)
for column_mapped_collection, including string,
text() etc., will raise an error message that
specifically asks for a column element, no longer
misleads with incorrect information about
text() or literal().  [ticket:1863]

CHANGES
lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/sql/expression.py
test/orm/test_collection.py

diff --git a/CHANGES b/CHANGES
index af4fdb8d830560dc53fb145b28bce08d48004607..3d3bd724d53613691271e25309d5f262c5e59b91 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -47,6 +47,13 @@ CHANGES
     subclass attributes that "disable" propagation from
     the parent - these needed to allow a merge()
     operation to pass through without effect.
+
+  - Specifying a non-column based argument
+    for column_mapped_collection, including string,
+    text() etc., will raise an error message that
+    specifically asks for a column element, no longer
+    misleads with incorrect information about
+    text() or literal().  [ticket:1863]
     
 - sql
   - Changed the scheme used to generate truncated
index 0ea17cd8bd72b448573d42d90304651c63e0f387..b5c4353b3b3f75f1c8d72055f85e508ec80a7d02 100644 (file)
@@ -129,7 +129,7 @@ def column_mapped_collection(mapping_spec):
     from sqlalchemy.orm.util import _state_mapper
     from sqlalchemy.orm.attributes import instance_state
 
-    cols = [expression._no_literals(q) for q in util.to_list(mapping_spec)]
+    cols = [expression._only_column_elements(q) for q in util.to_list(mapping_spec)]
     if len(cols) == 1:
         def keyfunc(value):
             state = instance_state(value)
index 96147a94a0e8fc2589a0831d44d68cea7c675607..8a92dba0ddbd03d7b9ac4f950a0b195381dc304c 100644 (file)
@@ -1035,6 +1035,14 @@ def _no_literals(element):
     else:
         return element
 
+def _only_column_elements(element):
+    if hasattr(element, '__clause_element__'):
+        element = element.__clause_element__()
+    if not isinstance(element, ColumnElement):
+        raise exc.ArgumentError("Column-based expression object expected; "
+                                "got: %r" % element)
+    return element
+    
 def _corresponding_column_or_error(fromclause, column, require_embedded=False):
     c = fromclause.corresponding_column(column,
             require_embedded=require_embedded)
index 7dcc5c56f69d49c2a3a99c364c94423a28c7ddc4..405829f7434ff2dc23492eefba5e0be88bac9ba5 100644 (file)
@@ -7,12 +7,12 @@ from sqlalchemy.orm.collections import collection
 
 import sqlalchemy as sa
 from sqlalchemy.test import testing
-from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy import Integer, String, ForeignKey, text
 from sqlalchemy.test.schema import Table, Column
 from sqlalchemy import util, exc as sa_exc
 from sqlalchemy.orm import create_session, mapper, relationship, attributes
 from test.orm import _base
-from sqlalchemy.test.testing import eq_, assert_raises
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 
 class Canary(sa.orm.interfaces.AttributeExtension):
     def __init__(self):
@@ -1561,6 +1561,25 @@ class DictHelpersTest(_base.MappedTest):
             
         eq_(Bar.foos.property.collection_class().keyfunc(Foo(id=3)), 3)
         eq_(Bar.foos2.property.collection_class().keyfunc(Foo(id=3, bar_id=12)), (3, 12))
+
+    @testing.resolve_artifact_names
+    def test_column_mapped_assertions(self):
+        assert_raises_message(
+            sa_exc.ArgumentError,
+            "Column-based expression object expected; got: 'a'",
+            collections.column_mapped_collection, "a",
+        )
+        assert_raises_message(
+            sa_exc.ArgumentError,
+            "Column-based expression object expected; got",
+            collections.column_mapped_collection, text("a"),
+        )
+        assert_raises_message(
+            sa_exc.ArgumentError,
+            "Column-based expression object expected; got",
+            collections.column_mapped_collection, text("a"),
+        )
+        
         
     @testing.resolve_artifact_names
     def test_column_mapped_collection(self):