From dba3ad87b5a762272a11450a18955468d019fe80 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 2 Aug 2010 16:59:06 -0400 Subject: [PATCH] - Specifying a non-column based argument for column_mapped_collection, including string, text() etc., will raise an error message that specifically asks for a column element, no longer misleads with incorrect information about text() or literal(). [ticket:1863] --- CHANGES | 7 +++++++ lib/sqlalchemy/orm/collections.py | 2 +- lib/sqlalchemy/sql/expression.py | 8 ++++++++ test/orm/test_collection.py | 23 +++++++++++++++++++++-- 4 files changed, 37 insertions(+), 3 deletions(-) diff --git a/CHANGES b/CHANGES index af4fdb8d83..3d3bd724d5 100644 --- a/CHANGES +++ b/CHANGES @@ -47,6 +47,13 @@ CHANGES subclass attributes that "disable" propagation from the parent - these needed to allow a merge() operation to pass through without effect. + + - Specifying a non-column based argument + for column_mapped_collection, including string, + text() etc., will raise an error message that + specifically asks for a column element, no longer + misleads with incorrect information about + text() or literal(). [ticket:1863] - sql - Changed the scheme used to generate truncated diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 0ea17cd8bd..b5c4353b3b 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -129,7 +129,7 @@ def column_mapped_collection(mapping_spec): from sqlalchemy.orm.util import _state_mapper from sqlalchemy.orm.attributes import instance_state - cols = [expression._no_literals(q) for q in util.to_list(mapping_spec)] + cols = [expression._only_column_elements(q) for q in util.to_list(mapping_spec)] if len(cols) == 1: def keyfunc(value): state = instance_state(value) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 96147a94a0..8a92dba0dd 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1035,6 +1035,14 @@ def _no_literals(element): else: return element +def _only_column_elements(element): + if hasattr(element, '__clause_element__'): + element = element.__clause_element__() + if not isinstance(element, ColumnElement): + raise exc.ArgumentError("Column-based expression object expected; " + "got: %r" % element) + return element + def _corresponding_column_or_error(fromclause, column, require_embedded=False): c = fromclause.corresponding_column(column, require_embedded=require_embedded) diff --git a/test/orm/test_collection.py b/test/orm/test_collection.py index 7dcc5c56f6..405829f743 100644 --- a/test/orm/test_collection.py +++ b/test/orm/test_collection.py @@ -7,12 +7,12 @@ from sqlalchemy.orm.collections import collection import sqlalchemy as sa from sqlalchemy.test import testing -from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy import Integer, String, ForeignKey, text from sqlalchemy.test.schema import Table, Column from sqlalchemy import util, exc as sa_exc from sqlalchemy.orm import create_session, mapper, relationship, attributes from test.orm import _base -from sqlalchemy.test.testing import eq_, assert_raises +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message class Canary(sa.orm.interfaces.AttributeExtension): def __init__(self): @@ -1561,6 +1561,25 @@ class DictHelpersTest(_base.MappedTest): eq_(Bar.foos.property.collection_class().keyfunc(Foo(id=3)), 3) eq_(Bar.foos2.property.collection_class().keyfunc(Foo(id=3, bar_id=12)), (3, 12)) + + @testing.resolve_artifact_names + def test_column_mapped_assertions(self): + assert_raises_message( + sa_exc.ArgumentError, + "Column-based expression object expected; got: 'a'", + collections.column_mapped_collection, "a", + ) + assert_raises_message( + sa_exc.ArgumentError, + "Column-based expression object expected; got", + collections.column_mapped_collection, text("a"), + ) + assert_raises_message( + sa_exc.ArgumentError, + "Column-based expression object expected; got", + collections.column_mapped_collection, text("a"), + ) + @testing.resolve_artifact_names def test_column_mapped_collection(self): -- 2.47.2