From: Mike Bayer Date: Mon, 23 Jul 2012 22:22:06 +0000 (-0400) Subject: - [feature] ORM entities can be passed X-Git-Tag: rel_0_8_0b1~305 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=65bdf245c6cfd4381f8463714fbec1880a950fbb;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - [feature] ORM entities can be passed to select() as well as the select_from(), correlate(), and correlate_except() methods, where they will be unwrapped into selectables. [ticket:2245] --- diff --git a/CHANGES b/CHANGES index ac0303c9c3..db71d310b1 100644 --- a/CHANGES +++ b/CHANGES @@ -151,6 +151,12 @@ underneath "0.7.xx". need autoflush w pre-attached object. [ticket:2464] + - [feature] ORM entities can be passed + to select() as well as the select_from(), + correlate(), and correlate_except() + methods, where they will be unwrapped + into selectables. [ticket:2245] + - [feature] The registry of classes in declarative_base() is now a WeakValueDictionary. So subclasses of diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 47349e64a8..045a9465dc 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -112,9 +112,11 @@ PASSIVE_ONLY_PERSISTENT = util.symbol("PASSIVE_ONLY_PERSISTENT", ) -class QueryableAttribute(interfaces.PropComparator): +class QueryableAttribute(interfaces._InspectionAttr, interfaces.PropComparator): """Base class for class-bound attributes. """ + is_attribute = True + def __init__(self, class_, key, impl=None, comparator=None, parententity=None, of_type=None): @@ -149,6 +151,10 @@ class QueryableAttribute(interfaces.PropComparator): # TODO: conditionally attach this method based on clause_element ? return self + @property + def expression(self): + return self.comparator.__clause_element__() + def __clause_element__(self): return self.comparator.__clause_element__() @@ -191,10 +197,7 @@ class QueryableAttribute(interfaces.PropComparator): def property(self): return self.comparator.property - -@inspection._inspects(QueryableAttribute) -def _get_prop(source): - return source.property +inspection._self_inspects(QueryableAttribute) class InstrumentedAttribute(QueryableAttribute): """Class bound instrumented attribute which adds descriptor methods.""" diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 84c75525a3..d0732b9135 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -62,6 +62,8 @@ class _InspectionAttr(object): is_instance = False is_mapper = False is_property = False + is_attribute = False + is_clause_element = False class MapperProperty(_InspectionAttr): """Manage the relationship of a ``Mapper`` to a single class diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 4533bbdb06..53ee1b5fd3 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -102,6 +102,7 @@ class ColumnProperty(StrategizedProperty): else: self.strategy_class = strategies.ColumnLoader + @property def expression(self): """Return the primary column or expression for this ColumnProperty. diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 35c70d51ed..0a345f2841 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -26,10 +26,11 @@ from . import ( ) from .util import ( AliasedClass, ORMAdapter, _entity_descriptor, PathRegistry, - _is_aliased_class, _is_mapped_class, _orm_columns, _orm_selectable, + _is_aliased_class, _is_mapped_class, _orm_columns, join as orm_join,with_parent, aliased ) -from .. import sql, util, log, exc as sa_exc, inspect +from .. import sql, util, log, exc as sa_exc, inspect, inspection +from ..sql.expression import _interpret_as_from from ..sql import ( util as sql_util, expression, visitors @@ -539,6 +540,9 @@ class Query(object): return self.enable_eagerloads(False).statement.as_scalar() + @property + def selectable(self): + return self.__clause_element__() def __clause_element__(self): return self.enable_eagerloads(False).with_labels().statement @@ -798,7 +802,8 @@ class Query(object): """ self._correlate = self._correlate.union( - _orm_selectable(s) + _interpret_as_from(s) + if s is not None else None for s in args) @_generative() @@ -2672,7 +2677,6 @@ class Query(object): statement.append_order_by(*context.eager_order_by) return statement - def _adjust_for_single_inheritance(self, context): """Apply single-table-inheritance filtering. @@ -2696,6 +2700,7 @@ class Query(object): def __str__(self): return str(self._compile_context().statement) +inspection._self_inspects(Query) class _QueryEntity(object): """represent an entity column returned within a Query result.""" diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 8f340d3667..27d9b1b69f 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1074,13 +1074,6 @@ def _orm_columns(entity): else: return [entity] -def _orm_selectable(entity): - insp = inspection.inspect(entity, False) - if hasattr(insp, 'selectable'): - return insp.selectable - else: - return entity - def has_identity(object): state = attributes.instance_state(object) return state.has_identity diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index b9c149954a..a518852d89 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1426,6 +1426,19 @@ def _literal_as_text(element): "SQL expression object or string expected." ) +def _interpret_as_from(element): + insp = inspection.inspect(element, raiseerr=False) + if insp is None: + if isinstance(element, (util.NoneType, bool)): + return _const_expr(element) + elif isinstance(element, basestring): + return TextClause(unicode(element)) + elif hasattr(insp, "selectable"): + return insp.selectable + else: + raise exc.ArgumentError("FROM expression expected") + + def _const_expr(element): if element is None: return null() @@ -1445,12 +1458,15 @@ def _clause_element_as_expr(element): return element def _literal_as_column(element): - if isinstance(element, Visitable): - return element - elif hasattr(element, '__clause_element__'): - return element.__clause_element__() - else: - return literal_column(str(element)) + insp = inspection.inspect(element, raiseerr=False) + if insp is not None: + if hasattr(insp, "expression"): + return insp.expression + elif hasattr(insp, "selectable"): + return insp.selectable + elif insp.is_clause_element: + return insp + return literal_column(str(element)) def _literal_as_binds(element, name=None, type_=None): if hasattr(element, '__clause_element__'): @@ -1539,6 +1555,7 @@ class ClauseElement(Visitable): bind = None _is_clone_of = None is_selectable = False + is_clause_element = True def _clone(self): """Create a shallow copy of this ClauseElement. @@ -2173,6 +2190,15 @@ class ColumnElement(ClauseElement, CompareMixin): _key_label = None _alt_names = () + @property + def expression(self): + """Return a column expression. + + Part of the inspection interface; returns self. + + """ + return self + @property def _select_iterable(self): return (self, ) @@ -2973,6 +2999,10 @@ class TextClause(Executable, ClauseElement): def _select_iterable(self): return (self,) + @property + def selectable(self): + return self + _hide_froms = [] def __init__( @@ -5315,7 +5345,8 @@ class Select(SelectBase): if fromclauses and fromclauses[0] is None: self._correlate = () else: - self._correlate = set(self._correlate).union(fromclauses) + self._correlate = set(self._correlate).union( + _interpret_as_from(f) for f in fromclauses) @_generative def correlate_except(self, *fromclauses): @@ -5323,15 +5354,16 @@ class Select(SelectBase): if fromclauses and fromclauses[0] is None: self._correlate_except = () else: - self._correlate_except = set(self._correlate_except - ).union(fromclauses) + self._correlate_except = set(self._correlate_except).union( + _interpret_as_from(f) for f in fromclauses) def append_correlation(self, fromclause): """append the given correlation expression to this select() construct.""" self._should_correlate = False - self._correlate = set(self._correlate).union([fromclause]) + self._correlate = set(self._correlate).union( + _interpret_as_from(f) for f in fromclause) def append_column(self, column): """append the given column expression to the columns clause of this @@ -5387,7 +5419,7 @@ class Select(SelectBase): """ self._reset_exported() - fromclause = _literal_as_text(fromclause) + fromclause = _interpret_as_from(fromclause) self._from_obj = self._from_obj.union([fromclause]) def _populate_column_collection(self): diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py index f6037d0719..546c662047 100644 --- a/test/orm/test_froms.py +++ b/test/orm/test_froms.py @@ -65,7 +65,6 @@ class QueryTest(_fixtures.FixtureTest): configure_mappers() - class QueryCorrelatesLikeSelect(QueryTest, AssertsCompiledSQL): query_correlated = "SELECT users.name AS users_name, " \ diff --git a/test/orm/test_inspect.py b/test/orm/test_inspect.py index 487182fac2..1c5cab8a03 100644 --- a/test/orm/test_inspect.py +++ b/test/orm/test_inspect.py @@ -178,17 +178,23 @@ class TestORMInspection(_fixtures.FixtureTest): def test_insp_prop(self): User = self.classes.User prop = inspect(User.addresses) - is_(prop, User.addresses.property) + is_(prop, User.addresses) + + def test_insp_aliased_prop(self): + User = self.classes.User + ua = aliased(User) + prop = inspect(ua.addresses) + is_(prop, ua.addresses) def test_rel_accessors(self): User = self.classes.User Address = self.classes.Address prop = inspect(User.addresses) - is_(prop.parent, class_mapper(User)) - is_(prop.mapper, class_mapper(Address)) + is_(prop.property.parent, class_mapper(User)) + is_(prop.property.mapper, class_mapper(Address)) assert not hasattr(prop, 'columns') - assert not hasattr(prop, 'expression') + assert hasattr(prop, 'expression') def test_instance_state(self): diff --git a/test/orm/test_query.py b/test/orm/test_query.py index e8d6c0901d..b80db67ebc 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -119,6 +119,80 @@ class RowTupleTest(QueryTest): asserted ) +class RawSelectTest(QueryTest, AssertsCompiledSQL): + __dialect__ = 'default' + + def test_select_from_entity(self): + User = self.classes.User + + self.assert_compile( + select(['*']).select_from(User), + "SELECT * FROM users" + ) + + def test_select_from_aliased_entity(self): + User = self.classes.User + ua = aliased(User, name="ua") + self.assert_compile( + select(['*']).select_from(ua), + "SELECT * FROM users AS ua" + ) + + def test_correlate_entity(self): + User = self.classes.User + Address = self.classes.Address + + self.assert_compile( + select([User]).where(User.id == Address.user_id). + correlate(Address), + "SELECT users.id, users.name FROM users " + "WHERE users.id = addresses.user_id" + ) + + def test_correlate_aliased_entity(self): + User = self.classes.User + Address = self.classes.Address + aa = aliased(Address, name="aa") + + self.assert_compile( + select([User]).where(User.id == aa.user_id). + correlate(aa), + "SELECT users.id, users.name FROM users " + "WHERE users.id = aa.user_id" + ) + + def test_columns_clause_entity(self): + User = self.classes.User + + self.assert_compile( + select([User]), + "SELECT users.id, users.name FROM users" + ) + + def test_columns_clause_columns(self): + User = self.classes.User + + self.assert_compile( + select([User.id, User.name]), + "SELECT users.id, users.name FROM users" + ) + + def test_columns_clause_aliased_columns(self): + User = self.classes.User + ua = aliased(User, name='ua') + self.assert_compile( + select([ua.id, ua.name]), + "SELECT ua.id, ua.name FROM users AS ua" + ) + + def test_columns_clause_aliased_entity(self): + User = self.classes.User + ua = aliased(User, name='ua') + self.assert_compile( + select([ua]), + "SELECT ua.id, ua.name FROM users AS ua" + ) + class GetTest(QueryTest): def test_get(self): User = self.classes.User