From ac9eb5c9c3bc33c38eff5407fa4724c9277ba342 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 21 Aug 2010 19:38:28 -0400 Subject: [PATCH] - Similarly, for relationship(), foreign_keys, remote_side, order_by - all column-based expressions are enforced - lists of strings are explicitly disallowed since this is a very common error --- CHANGES | 6 ++++++ lib/sqlalchemy/orm/collections.py | 3 ++- lib/sqlalchemy/orm/properties.py | 11 ++++++----- lib/sqlalchemy/sql/expression.py | 6 +++--- test/orm/test_collection.py | 27 ++++++++++++--------------- test/orm/test_relationships.py | 30 ++++++++++++++++++++++++++++++ 6 files changed, 59 insertions(+), 24 deletions(-) diff --git a/CHANGES b/CHANGES index aae1139a14..e099f68ae0 100644 --- a/CHANGES +++ b/CHANGES @@ -87,6 +87,12 @@ CHANGES misleads with incorrect information about text() or literal(). [ticket:1863] + - Similarly, for relationship(), foreign_keys, + remote_side, order_by - all column-based + expressions are enforced - lists of strings + are explicitly disallowed since this is a + very common error + - Dynamic attributes don't support collection population - added an assertion for when set_committed_value() is called, as well as diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index b5c4353b3b..a9ad342390 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -129,7 +129,8 @@ def column_mapped_collection(mapping_spec): from sqlalchemy.orm.util import _state_mapper from sqlalchemy.orm.attributes import instance_state - cols = [expression._only_column_elements(q) for q in util.to_list(mapping_spec)] + cols = [expression._only_column_elements(q, "mapping_spec") + for q in util.to_list(mapping_spec)] if len(cols) == 1: def keyfunc(value): state = instance_state(value) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 5788c30f94..7e19d7b16d 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -926,16 +926,17 @@ class RelationshipProperty(StrategizedProperty): for attr in 'primaryjoin', 'secondaryjoin': val = getattr(self, attr) if val is not None: - util.assert_arg_type(val, sql.ColumnElement, attr) - setattr(self, attr, _orm_deannotate(val)) + setattr(self, attr, _orm_deannotate( + expression._only_column_elements(val, attr)) + ) if self.order_by is not False and self.order_by is not None: - self.order_by = [expression._literal_as_column(x) for x in + self.order_by = [expression._only_column_elements(x, "order_by") for x in util.to_list(self.order_by)] self._user_defined_foreign_keys = \ - util.column_set(expression._literal_as_column(x) for x in + util.column_set(expression._only_column_elements(x, "foreign_keys") for x in util.to_column_set(self._user_defined_foreign_keys)) self.remote_side = \ - util.column_set(expression._literal_as_column(x) for x in + util.column_set(expression._only_column_elements(x, "remote_side") for x in util.to_column_set(self.remote_side)) if not self.parent.concrete: for inheriting in self.parent.iterate_to_root(): diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index a7f5d396a2..6f593ab484 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1037,12 +1037,12 @@ def _no_literals(element): else: return element -def _only_column_elements(element): +def _only_column_elements(element, name): if hasattr(element, '__clause_element__'): element = element.__clause_element__() if not isinstance(element, ColumnElement): - raise exc.ArgumentError("Column-based expression object expected; " - "got: %r" % element) + raise exc.ArgumentError("Column-based expression object expected for argument '%s'; " + "got: '%s', type %s" % (name, element, type(element))) return element def _corresponding_column_or_error(fromclause, column, diff --git a/test/orm/test_collection.py b/test/orm/test_collection.py index 405829f743..a33d2d6d13 100644 --- a/test/orm/test_collection.py +++ b/test/orm/test_collection.py @@ -1564,21 +1564,18 @@ class DictHelpersTest(_base.MappedTest): @testing.resolve_artifact_names def test_column_mapped_assertions(self): - assert_raises_message( - sa_exc.ArgumentError, - "Column-based expression object expected; got: 'a'", - collections.column_mapped_collection, "a", - ) - assert_raises_message( - sa_exc.ArgumentError, - "Column-based expression object expected; got", - collections.column_mapped_collection, text("a"), - ) - assert_raises_message( - sa_exc.ArgumentError, - "Column-based expression object expected; got", - collections.column_mapped_collection, text("a"), - ) + assert_raises_message(sa_exc.ArgumentError, + "Column-based expression object expected " + "for argument 'mapping_spec'; got: 'a', " + "type ", + collections.column_mapped_collection, 'a') + assert_raises_message(sa_exc.ArgumentError, + "Column-based expression object expected " + "for argument 'mapping_spec'; got: 'a', " + "type ", + collections.column_mapped_collection, + text('a')) @testing.resolve_artifact_names diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py index ce315ff350..187c9e5349 100644 --- a/test/orm/test_relationships.py +++ b/test/orm/test_relationships.py @@ -970,6 +970,36 @@ class JoinConditionErrorTest(testing.TestBase): mapper(C2, t2) assert_raises(sa.exc.ArgumentError, compile_mappers) + def test_invalid_string_args(self): + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy import util + + for argname, arg in [ + ('remote_side', ['c1.id']), + ('remote_side', ['id']), + ('foreign_keys', ['c1id']), + ('foreign_keys', ['C2.c1id']), + ('order_by', ['id']), + ]: + clear_mappers() + kw = {argname:arg} + Base = declarative_base() + class C1(Base): + __tablename__ = 'c1' + id = Column('id', Integer, primary_key=True) + + class C2(Base): + __tablename__ = 'c2' + id_ = Column('id', Integer, primary_key=True) + c1id = Column('c1id', Integer, ForeignKey('c1.id')) + c2 = relationship(C1, **kw) + + assert_raises_message( + sa.exc.ArgumentError, + "Column-based expression object expected for argument '%s'; got: '%s', type %r" % (argname, arg[0], type(arg[0])), + compile_mappers) + + def test_fk_error_raised(self): m = MetaData() t1 = Table('t1', m, -- 2.47.2