From: Mike Bayer Date: Mon, 2 Jun 2008 03:07:12 +0000 (+0000) Subject: - removed query.min()/max()/sum()/avg(). these should be called using column argumen... X-Git-Tag: rel_0_5beta1~10 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e525aee01556e59ff9fc02dd68fd6a38532fe45a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - removed query.min()/max()/sum()/avg(). these should be called using column arguments or values in conjunction with func. - fixed [ticket:1008], count() works with single table inheritance - changed the relationship of InstrumentedAttribute to class such that each subclass in an inheritance hierarchy gets a unique InstrumentedAttribute per column-oriented attribute, including for the same underlying ColumnProperty. This allows expressions from subclasses to be annotated accurately so that Query can get a hold of the exact entities to be queried when using column-based expressions. This repairs various polymorphic scenarios with both single and joined table inheritance. - still to be determined is what does something like query(Person.name, Engineer.engineer_info) do; currently it's problematic. Even trickier is query(Person.name, Engineer.engineer_info, Manager.manager_name) --- diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 6c9fe77533..7b120e884f 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -439,9 +439,10 @@ class PropComparator(expression.ColumnOperators): return a.has(b, **kwargs) has_op = staticmethod(has_op) - def __init__(self, prop): + def __init__(self, prop, mapper): self.prop = self.property = prop - + self.mapper = mapper + def of_type_op(a, class_): return a.of_type(class_) of_type_op = staticmethod(of_type_op) @@ -753,11 +754,12 @@ class LoaderStrategy(object): def __init__(self, parent): self.parent_property = parent self.is_class_level = False - - def init(self): self.parent = self.parent_property.parent self.key = self.parent_property.key + def init(self): + raise NotImplementedError("LoaderStrategy") + def init_class_attribute(self): pass diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index c7300f2160..792824fe3f 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -38,7 +38,7 @@ class ColumnProperty(StrategizedProperty): self.columns = [expression._labeled(c) for c in columns] self.group = kwargs.pop('group', None) self.deferred = kwargs.pop('deferred', False) - self.comparator = ColumnProperty.ColumnComparator(self) + self.comparator_factory = ColumnProperty.ColumnComparator util.set_creation_order(self) if self.deferred: self.strategy_class = strategies.DeferredColumnLoader @@ -80,7 +80,7 @@ class ColumnProperty(StrategizedProperty): class ColumnComparator(PropComparator): def __clause_element__(self): - return self.prop.columns[0]._annotate({"parententity": self.prop.parent}) + return self.prop.columns[0]._annotate({"parententity": self.mapper}) __clause_element__ = util.cache_decorator(__clause_element__) def operate(self, op, *other, **kwargs): @@ -101,7 +101,7 @@ class CompositeProperty(ColumnProperty): def __init__(self, class_, *columns, **kwargs): super(CompositeProperty, self).__init__(*columns, **kwargs) self.composite_class = class_ - self.comparator = kwargs.pop('comparator', CompositeProperty.Comparator)(self) + self.comparator_factory = kwargs.pop('comparator', CompositeProperty.Comparator) self.strategy_class = strategies.CompositeColumnLoader def do_init(self): @@ -170,8 +170,7 @@ class SynonymProperty(MapperProperty): def do_init(self): class_ = self.parent.class_ - def comparator(): - return self.parent._get_property(self.key, resolve_synonyms=True).comparator + self.logger.info("register managed attribute %s on class %s" % (self.key, class_.__name__)) if self.descriptor is None: class SynonymProp(object): @@ -184,7 +183,14 @@ class SynonymProperty(MapperProperty): return s return getattr(obj, self.name) self.descriptor = SynonymProp() - sessionlib.register_attribute(class_, self.key, uselist=False, proxy_property=self.descriptor, useobject=False, comparator=comparator, parententity=self.parent) + + def comparator_callable(prop, mapper): + def comparator(): + prop = self.parent._get_property(self.key, resolve_synonyms=True) + return prop.comparator_factory(prop, mapper) + return comparator + + strategies.DefaultColumnLoader(self)._register_attribute(None, None, False, comparator_callable, proxy_property=self.descriptor) def merge(self, session, source, dest, _recursive): pass @@ -195,18 +201,13 @@ class ComparableProperty(MapperProperty): def __init__(self, comparator_factory, descriptor=None): self.descriptor = descriptor - self.comparator = comparator_factory(self) + self.comparator_factory = comparator_factory util.set_creation_order(self) def do_init(self): """Set up a proxy to the unmanaged descriptor.""" - class_ = self.parent.class_ - # refactor me - sessionlib.register_attribute(class_, self.key, uselist=False, - proxy_property=self.descriptor, - useobject=False, - comparator=self.comparator) + strategies.DefaultColumnLoader(self)._register_attribute(None, None, False, self.comparator_factory, proxy_property=self.descriptor) def setup(self, context, entity, path, adapter, **kwargs): pass @@ -252,10 +253,11 @@ class PropertyLoader(StrategizedProperty): self.passive_updates = passive_updates self.remote_side = remote_side self.enable_typechecks = enable_typechecks - self.comparator = PropertyLoader.Comparator(self) + self.comparator = PropertyLoader.Comparator(self, None) self.join_depth = join_depth self.local_remote_pairs = _local_remote_pairs self.__join_cache = {} + self.comparator_factory = PropertyLoader.Comparator util.set_creation_order(self) if strategy_class: @@ -295,8 +297,9 @@ class PropertyLoader(StrategizedProperty): self._is_backref = _is_backref class Comparator(PropComparator): - def __init__(self, prop, of_type=None): + def __init__(self, prop, mapper, of_type=None): self.prop = self.property = prop + self.mapper = mapper if of_type: self._of_type = _class_to_mapper(of_type) @@ -314,7 +317,7 @@ class PropertyLoader(StrategizedProperty): return op(self, *other, **kwargs) def of_type(self, cls): - return PropertyLoader.Comparator(self.prop, cls) + return PropertyLoader.Comparator(self.prop, self.mapper, cls) def __eq__(self, other): if other is None: diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 555c376e54..43f206f38a 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -97,7 +97,14 @@ class Query(object): self.__setup_aliasizers(self._entities) def __setup_aliasizers(self, entities): - d = {} + if hasattr(self, '_mapper_adapter_map'): + # usually safe to share a single map, but copying to prevent + # subtle leaks if end-user is reusing base query with arbitrary + # number of aliased() objects + self._mapper_adapter_map = d = self._mapper_adapter_map.copy() + else: + self._mapper_adapter_map = d = {} + for ent in entities: for entity in ent.entities: if entity not in d: @@ -114,7 +121,7 @@ class Query(object): d[entity] = (mapper, adapter, selectable, is_aliased_class, with_polymorphic) ent.setup_entity(entity, *d[entity]) - + def __mapper_loads_polymorphically_with(self, mapper, adapter): for m2 in mapper._with_polymorphic_mappers: for m in m2.iterate_to_root(): @@ -650,26 +657,6 @@ class Query(object): return self.filter(sql.and_(*clauses)) - def min(self, col): - """Execute the SQL ``min()`` function against the given column.""" - - return self._col_aggregate(col, sql.func.min) - - def max(self, col): - """Execute the SQL ``max()`` function against the given column.""" - - return self._col_aggregate(col, sql.func.max) - - def sum(self, col): - """Execute the SQL ``sum()`` function against the given column.""" - - return self._col_aggregate(col, sql.func.sum) - - def avg(self, col): - """Execute the SQL ``avg()`` function against the given column.""" - - return self._col_aggregate(col, sql.func.avg) - def order_by(self, *criterion): """apply one or more ORDER BY criterion to the query and return the newly resulting ``Query``""" @@ -1213,18 +1200,17 @@ class Query(object): _should_nest_selectable = property(_should_nest_selectable) def count(self): - """Apply this query's criterion to a SELECT COUNT statement. - - this is the purely generative version which will become - the public method in version 0.5. - - """ - return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=list(self._mapper_zero().primary_key)) + """Apply this query's criterion to a SELECT COUNT statement.""" + + return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=list(self._only_mapper_zero().primary_key)) def _col_aggregate(self, col, func, nested_cols=None): - whereclause = self._criterion - context = QueryContext(self) + + self._adjust_for_single_inheritance(context) + + whereclause = context.whereclause + from_obj = self.__mapper_zero_from_obj() if self._should_nest_selectable: @@ -1371,7 +1357,9 @@ class Query(object): froms = [context.from_clause] # "load from a single FROM" mode, i.e. when select_from() or join() is used else: froms = context.froms # "load from discrete FROMs" mode, i.e. when each _MappedEntity has its own FROM - + + self._adjust_for_single_inheritance(context) + if eager_joins and self._should_nest_selectable: # for eager joins present and LIMIT/OFFSET/DISTINCT, wrap the query inside a select, # then append eager joins onto that @@ -1382,7 +1370,15 @@ class Query(object): context.order_by = None order_by_col_expr = [] - inner = sql.select(context.primary_columns + order_by_col_expr, context.whereclause, from_obj=froms, use_labels=labels, correlate=False, order_by=context.order_by, **self._select_args) + inner = sql.select( + context.primary_columns + order_by_col_expr, + context.whereclause, + from_obj=froms, + use_labels=labels, + correlate=False, + order_by=context.order_by, + **self._select_args + ) if self._correlate: inner = inner.correlate(*self._correlate) @@ -1418,7 +1414,17 @@ class Query(object): froms += context.eager_joins.values() - statement = sql.select(context.primary_columns + context.secondary_columns, context.whereclause, from_obj=froms, use_labels=labels, for_update=for_update, correlate=False, order_by=context.order_by, **self._select_args) + statement = sql.select( + context.primary_columns + context.secondary_columns, + context.whereclause, + from_obj=froms, + use_labels=labels, + for_update=for_update, + correlate=False, + order_by=context.order_by, + **self._select_args + ) + if self._correlate: statement = statement.correlate(*self._correlate) @@ -1429,6 +1435,22 @@ class Query(object): return context + def _adjust_for_single_inheritance(self, context): + """Apply single-table-inheritance filtering. + + For all distinct single-table-inheritance mappers represented in the columns + clause of this query, add criterion to the WHERE clause of the given QueryContext + such that only the appropriate subtypes are selected from the total results. + + """ + for entity, (mapper, adapter, s, i, w) in self._mapper_adapter_map.iteritems(): + if mapper.single and mapper.inherits and mapper.polymorphic_on and mapper.polymorphic_identity is not None: + crit = mapper.polymorphic_on.in_([m.polymorphic_identity for m in mapper.polymorphic_iterator()]) + if adapter: + crit = adapter.traverse(crit) + crit = self._adapt_clause(crit, False, False) + context.whereclause = sql.and_(context.whereclause, crit) + def __log_debug(self, msg): self.logger.debug(msg) @@ -1463,7 +1485,7 @@ class _MapperEntity(_QueryEntity): self.entities = [entity] self.entity_zero = entity self.entity_name = entity_name - + def setup_entity(self, entity, mapper, adapter, from_obj, is_aliased_class, with_polymorphic): self.mapper = mapper self.extension = self.mapper.extension @@ -1554,15 +1576,10 @@ class _MapperEntity(_QueryEntity): return main, entname def setup_context(self, query, context): - # if single-table inheritance mapper, add "typecol IN (polymorphic)" criterion so - # that we only load the appropriate types - if self.mapper.single and self.mapper.inherits is not None and self.mapper.polymorphic_on is not None and self.mapper.polymorphic_identity is not None: - context.whereclause = sql.and_(context.whereclause, self.mapper.polymorphic_on.in_([m.polymorphic_identity for m in self.mapper.polymorphic_iterator()])) + adapter = self._get_entity_clauses(query, context) context.froms.append(self.selectable) - adapter = self._get_entity_clauses(query, context) - if context.order_by is False and self.mapper.order_by: context.order_by = self.mapper.order_by diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index d40937a18c..829210205e 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -17,11 +17,31 @@ from sqlalchemy.orm import session as sessionlib from sqlalchemy.orm import util as mapperutil -class ColumnLoader(LoaderStrategy): - """Default column loader.""" +class DefaultColumnLoader(LoaderStrategy): + def _register_attribute(self, compare_function, copy_function, mutable_scalars, comparator_factory, callable_=None, proxy_property=None): + self.logger.info("%s register managed attribute" % self) + + for mapper in self.parent.polymorphic_iterator(): + if mapper is self.parent or not mapper.concrete: + sessionlib.register_attribute( + mapper.class_, + self.key, + uselist=False, + useobject=False, + copy_function=copy_function, + compare_function=compare_function, + mutable_scalars=mutable_scalars, + comparator=comparator_factory(self.parent_property, mapper), + parententity=mapper, + callable_=callable_, + proxy_property=proxy_property + ) + +DefaultColumnLoader.logger = log.class_logger(DefaultColumnLoader) + +class ColumnLoader(DefaultColumnLoader): def init(self): - super(ColumnLoader, self).init() self.columns = self.parent_property.columns self._should_log_debug = log.is_debug_enabled(self.logger) self.is_composite = hasattr(self.parent_property, 'composite_class') @@ -34,9 +54,14 @@ class ColumnLoader(LoaderStrategy): def init_class_attribute(self): self.is_class_level = True - self.logger.info("%s register managed attribute" % self) coltype = self.columns[0].type - sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator, parententity=self.parent) + + self._register_attribute( + coltype.compare_values, + coltype.copy_value, + self.columns[0].type.is_mutable(), + self.parent_property.comparator_factory + ) def create_row_processor(self, selectcontext, path, mapper, row, adapter): key, col = self.key, self.columns[0] @@ -78,7 +103,13 @@ class CompositeColumnLoader(ColumnLoader): return False else: return True - sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=copy, compare_function=compare, mutable_scalars=True, comparator=self.parent_property.comparator, parententity=self.parent) + + self._register_attribute( + compare, + copy, + True, + self.parent_property.comparator_factory + ) def create_row_processor(self, selectcontext, path, mapper, row, adapter): key, columns, composite_class = self.key, self.columns, self.parent_property.composite_class @@ -106,7 +137,7 @@ class CompositeColumnLoader(ColumnLoader): CompositeColumnLoader.logger = log.class_logger(CompositeColumnLoader) -class DeferredColumnLoader(LoaderStrategy): +class DeferredColumnLoader(DefaultColumnLoader): """Deferred column loader, a per-column or per-column-group lazy loader.""" def create_row_processor(self, selectcontext, path, mapper, row, adapter): @@ -130,7 +161,6 @@ class DeferredColumnLoader(LoaderStrategy): return (new_execute, None) def init(self): - super(DeferredColumnLoader, self).init() if hasattr(self.parent_property, 'composite_class'): raise NotImplementedError("Deferred loading for composite types not implemented yet") self.columns = self.parent_property.columns @@ -139,8 +169,13 @@ class DeferredColumnLoader(LoaderStrategy): def init_class_attribute(self): self.is_class_level = True - self.logger.info("%s register managed attribute" % self) - sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, callable_=self.class_level_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator, parententity=self.parent) + self._register_attribute( + self.columns[0].type.compare_values, + self.columns[0].type.copy_value, + self.columns[0].type.is_mutable(), + self.parent_property.comparator_factory, + callable_=self.class_level_loader, + ) def setup_query(self, context, entity, path, adapter, only_load_props=None, **kwargs): if \ @@ -238,7 +273,6 @@ class UndeferGroupOption(MapperOption): class AbstractRelationLoader(LoaderStrategy): def init(self): - super(AbstractRelationLoader, self).init() for attr in ['mapper', 'target', 'table', 'uselist']: setattr(self, attr, getattr(self.parent_property, attr)) self._should_log_debug = log.is_debug_enabled(self.logger) @@ -249,7 +283,7 @@ class AbstractRelationLoader(LoaderStrategy): else: state.initialize(self.key) - def _register_attribute(self, class_, callable_=None, **kwargs): + def _register_attribute(self, class_, callable_=None, impl_class=None, **kwargs): self.logger.info("%s register managed %s attribute" % (self, (self.uselist and "collection" or "scalar"))) if self.parent_property.backref: @@ -257,7 +291,21 @@ class AbstractRelationLoader(LoaderStrategy): else: attribute_ext = None - sessionlib.register_attribute(class_, self.key, uselist=self.uselist, useobject=True, extension=attribute_ext, cascade=self.parent_property.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator, parententity=self.parent, **kwargs) + sessionlib.register_attribute( + class_, + self.key, + uselist=self.uselist, + useobject=True, + extension=attribute_ext, + cascade=self.parent_property.cascade, + trackparent=True, + typecallable=self.parent_property.collection_class, + callable_=callable_, + comparator=self.parent_property.comparator, + parententity=self.parent, + impl_class=impl_class, + **kwargs + ) class NoLoader(AbstractRelationLoader): def init_class_attribute(self): diff --git a/test/orm/generative.py b/test/orm/generative.py index 652823b28e..dd27092029 100644 --- a/test/orm/generative.py +++ b/test/orm/generative.py @@ -1,6 +1,6 @@ import testenv; testenv.configure_for_tests() from testlib import testing, sa -from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData +from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData, func from sqlalchemy.orm import mapper, relation, create_session from testlib.testing import eq_ from testlib.compat import set @@ -57,8 +57,9 @@ class GenerativeQueryTest(_base.MappedTest): sess = create_session() query = sess.query(Foo) assert query.count() == 100 - assert query.filter(foo.c.bar<30).min(foo.c.bar) == 0 - assert query.filter(foo.c.bar<30).max(foo.c.bar) == 29 + assert sess.query(func.min(foo.c.bar)).filter(foo.c.bar<30).one() == (0,) + + assert sess.query(func.max(foo.c.bar)).filter(foo.c.bar<30).one() == (29,) assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).next()[0] == 29 assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).next()[0] == 29 @@ -68,14 +69,14 @@ class GenerativeQueryTest(_base.MappedTest): testing.db.dialect.dbapi.version_info[:4] == (1, 2, 1, 'gamma')): return - query = create_session().query(Foo) - assert query.filter(foo.c.bar<30).sum(foo.c.bar) == 435 + query = create_session().query(func.sum(foo.c.bar)) + assert query.filter(foo.c.bar<30).one() == (435,) @testing.fails_on('firebird', 'mssql') @testing.resolve_artifact_names def test_aggregate_2(self): - query = create_session().query(Foo) - avg = query.filter(foo.c.bar < 30).avg(foo.c.bar) + query = create_session().query(func.avg(foo.c.bar)) + avg = query.filter(foo.c.bar < 30).one()[0] eq_(round(avg, 1), 14.5) @testing.resolve_artifact_names diff --git a/test/orm/inheritance/query.py b/test/orm/inheritance/query.py index 1cf09582e0..c706acf6e1 100644 --- a/test/orm/inheritance/query.py +++ b/test/orm/inheritance/query.py @@ -510,7 +510,7 @@ def make_test(select_type): self.assertEquals(sess.query(Person).filter(Person.person_id==subq).one(), e1) - + def test_mixed_entities(self): sess = create_session() @@ -525,6 +525,41 @@ def make_test(select_type): [(Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'), u'Elbonia, Inc.')] ) + + + self.assertEquals( + sess.query(Manager.name).all(), + [('pointy haired boss', ), ('dogbert',)] + ) + + self.assertEquals( + sess.query(Manager.name + " foo").all(), + [('pointy haired boss foo', ), ('dogbert foo',)] + ) + + + self.assertEquals( + sess.query(Engineer.name, Engineer.primary_language).all(), + [(u'dilbert', u'java'), (u'wally', u'c++'), (u'vlad', u'cobol')] + ) + + self.assertEquals( + sess.query(Boss.name, Boss.golf_swing).all(), + [(u'pointy haired boss', u'fore')] + ) + + # TODO: I think raise error on these for now. different inheritance/loading schemes have different + # results here, all incorrect + # + # self.assertEquals( + # sess.query(Person.name, Engineer.primary_language).all(), + # [] + # ) + + # self.assertEquals( + # sess.query(Person.name, Engineer.primary_language, Manager.manager_name).all(), + # [] + # ) self.assertEquals( sess.query(Person.name, Company.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(), diff --git a/test/orm/inheritance/single.py b/test/orm/inheritance/single.py index 58e7ad82a6..17497177e4 100644 --- a/test/orm/inheritance/single.py +++ b/test/orm/inheritance/single.py @@ -3,8 +3,9 @@ from sqlalchemy import * from sqlalchemy.orm import * from testlib import * from testlib.fixtures import Base +from orm._base import MappedTest, ComparableEntity -class SingleInheritanceTest(ORMTest): +class SingleInheritanceTest(MappedTest): def define_tables(self, metadata): global employees_table employees_table = Table('employees', metadata, @@ -14,9 +15,9 @@ class SingleInheritanceTest(ORMTest): Column('engineer_info', String(50)), Column('type', String(20)) ) - - def test_single_inheritance(self): - class Employee(Base): + + def setup_classes(self): + class Employee(ComparableEntity): pass class Manager(Employee): pass @@ -25,19 +26,22 @@ class SingleInheritanceTest(ORMTest): class JuniorEngineer(Engineer): pass + @testing.resolve_artifact_names + def setup_mappers(self): mapper(Employee, employees_table, polymorphic_on=employees_table.c.type) mapper(Manager, inherits=Employee, polymorphic_identity='manager') mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') mapper(JuniorEngineer, inherits=Engineer, polymorphic_identity='juniorengineer') + + @testing.resolve_artifact_names + def test_single_inheritance(self): session = create_session() m1 = Manager(name='Tom', manager_data='knows how to manage things') e1 = Engineer(name='Kurt', engineer_info='knows how to hack') e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed') - session.save(m1) - session.save(e1) - session.save(e2) + session.add_all([m1, e1, e2]) session.flush() assert session.query(Employee).all() == [m1, e1, e2] @@ -48,6 +52,86 @@ class SingleInheritanceTest(ORMTest): m1 = session.query(Manager).one() session.expire(m1, ['manager_data']) self.assertEquals(m1.manager_data, "knows how to manage things") + + @testing.resolve_artifact_names + def test_multi_qualification(self): + session = create_session() + + m1 = Manager(name='Tom', manager_data='knows how to manage things') + e1 = Engineer(name='Kurt', engineer_info='knows how to hack') + e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed') + + session.add_all([m1, e1, e2]) + session.flush() + + ealias = aliased(Engineer) + self.assertEquals( + session.query(Manager, ealias).all(), + [(m1, e1), (m1, e2)] + ) + + self.assertEquals( + session.query(Manager.name).all(), + [("Tom",)] + ) + + self.assertEquals( + session.query(Manager.name, ealias.name).all(), + [("Tom", "Kurt"), ("Tom", "Ed")] + ) + + self.assertEquals( + session.query(func.upper(Manager.name), func.upper(ealias.name)).all(), + [("TOM", "KURT"), ("TOM", "ED")] + ) + + self.assertEquals( + session.query(Manager).add_entity(ealias).all(), + [(m1, e1), (m1, e2)] + ) + + self.assertEquals( + session.query(Manager.name).add_column(ealias.name).all(), + [("Tom", "Kurt"), ("Tom", "Ed")] + ) + + # TODO: I think raise error on this for now + # self.assertEquals( + # session.query(Employee.name, Manager.manager_data, Engineer.engineer_info).all(), + # [] + # ) + + @testing.resolve_artifact_names + def test_select_from(self): + sess = create_session() + m1 = Manager(name='Tom', manager_data='data1') + m2 = Manager(name='Tom2', manager_data='data2') + e1 = Engineer(name='Kurt', engineer_info='knows how to hack') + e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed') + sess.add_all([m1, m2, e1, e2]) + sess.flush() + + self.assertEquals( + sess.query(Manager).select_from(employees_table.select().limit(10)).all(), + [m1, m2] + ) + + @testing.resolve_artifact_names + def test_count(self): + sess = create_session() + m1 = Manager(name='Tom', manager_data='data1') + m2 = Manager(name='Tom2', manager_data='data2') + e1 = Engineer(name='Kurt', engineer_info='data3') + e2 = JuniorEngineer(name='marvin', engineer_info='data4') + sess.add_all([m1, m2, e1, e2]) + sess.flush() + + self.assertEquals(sess.query(Manager).count(), 2) + self.assertEquals(sess.query(Engineer).count(), 2) + self.assertEquals(sess.query(Employee).count(), 4) + + self.assertEquals(sess.query(Manager).filter(Manager.name.like('%m%')).count(), 2) + self.assertEquals(sess.query(Employee).filter(Employee.name.like('%m%')).count(), 3) class SingleOnJoinedTest(ORMTest): def define_tables(self, metadata): diff --git a/test/orm/query.py b/test/orm/query.py index f816419187..fe70dec528 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -542,7 +542,7 @@ class AggregateTest(QueryTest): def test_sum(self): sess = create_session() orders = sess.query(Order).filter(Order.id.in_([2, 3, 4])) - assert orders.sum(Order.user_id * Order.address_id) == 79 + self.assertEquals(orders.values(func.sum(Order.user_id * Order.address_id)).next(), (79,)) def test_apply(self): sess = create_session()