return a.has(b, **kwargs)
has_op = staticmethod(has_op)
- def __init__(self, prop):
+ def __init__(self, prop, mapper):
self.prop = self.property = prop
-
+ self.mapper = mapper
+
def of_type_op(a, class_):
return a.of_type(class_)
of_type_op = staticmethod(of_type_op)
def __init__(self, parent):
self.parent_property = parent
self.is_class_level = False
-
- def init(self):
self.parent = self.parent_property.parent
self.key = self.parent_property.key
+ def init(self):
+ raise NotImplementedError("LoaderStrategy")
+
def init_class_attribute(self):
pass
self.columns = [expression._labeled(c) for c in columns]
self.group = kwargs.pop('group', None)
self.deferred = kwargs.pop('deferred', False)
- self.comparator = ColumnProperty.ColumnComparator(self)
+ self.comparator_factory = ColumnProperty.ColumnComparator
util.set_creation_order(self)
if self.deferred:
self.strategy_class = strategies.DeferredColumnLoader
class ColumnComparator(PropComparator):
def __clause_element__(self):
- return self.prop.columns[0]._annotate({"parententity": self.prop.parent})
+ return self.prop.columns[0]._annotate({"parententity": self.mapper})
__clause_element__ = util.cache_decorator(__clause_element__)
def operate(self, op, *other, **kwargs):
def __init__(self, class_, *columns, **kwargs):
super(CompositeProperty, self).__init__(*columns, **kwargs)
self.composite_class = class_
- self.comparator = kwargs.pop('comparator', CompositeProperty.Comparator)(self)
+ self.comparator_factory = kwargs.pop('comparator', CompositeProperty.Comparator)
self.strategy_class = strategies.CompositeColumnLoader
def do_init(self):
def do_init(self):
class_ = self.parent.class_
- def comparator():
- return self.parent._get_property(self.key, resolve_synonyms=True).comparator
+
self.logger.info("register managed attribute %s on class %s" % (self.key, class_.__name__))
if self.descriptor is None:
class SynonymProp(object):
return s
return getattr(obj, self.name)
self.descriptor = SynonymProp()
- sessionlib.register_attribute(class_, self.key, uselist=False, proxy_property=self.descriptor, useobject=False, comparator=comparator, parententity=self.parent)
+
+ def comparator_callable(prop, mapper):
+ def comparator():
+ prop = self.parent._get_property(self.key, resolve_synonyms=True)
+ return prop.comparator_factory(prop, mapper)
+ return comparator
+
+ strategies.DefaultColumnLoader(self)._register_attribute(None, None, False, comparator_callable, proxy_property=self.descriptor)
def merge(self, session, source, dest, _recursive):
pass
def __init__(self, comparator_factory, descriptor=None):
self.descriptor = descriptor
- self.comparator = comparator_factory(self)
+ self.comparator_factory = comparator_factory
util.set_creation_order(self)
def do_init(self):
"""Set up a proxy to the unmanaged descriptor."""
- class_ = self.parent.class_
- # refactor me
- sessionlib.register_attribute(class_, self.key, uselist=False,
- proxy_property=self.descriptor,
- useobject=False,
- comparator=self.comparator)
+ strategies.DefaultColumnLoader(self)._register_attribute(None, None, False, self.comparator_factory, proxy_property=self.descriptor)
def setup(self, context, entity, path, adapter, **kwargs):
pass
self.passive_updates = passive_updates
self.remote_side = remote_side
self.enable_typechecks = enable_typechecks
- self.comparator = PropertyLoader.Comparator(self)
+ self.comparator = PropertyLoader.Comparator(self, None)
self.join_depth = join_depth
self.local_remote_pairs = _local_remote_pairs
self.__join_cache = {}
+ self.comparator_factory = PropertyLoader.Comparator
util.set_creation_order(self)
if strategy_class:
self._is_backref = _is_backref
class Comparator(PropComparator):
- def __init__(self, prop, of_type=None):
+ def __init__(self, prop, mapper, of_type=None):
self.prop = self.property = prop
+ self.mapper = mapper
if of_type:
self._of_type = _class_to_mapper(of_type)
return op(self, *other, **kwargs)
def of_type(self, cls):
- return PropertyLoader.Comparator(self.prop, cls)
+ return PropertyLoader.Comparator(self.prop, self.mapper, cls)
def __eq__(self, other):
if other is None:
self.__setup_aliasizers(self._entities)
def __setup_aliasizers(self, entities):
- d = {}
+ if hasattr(self, '_mapper_adapter_map'):
+ # usually safe to share a single map, but copying to prevent
+ # subtle leaks if end-user is reusing base query with arbitrary
+ # number of aliased() objects
+ self._mapper_adapter_map = d = self._mapper_adapter_map.copy()
+ else:
+ self._mapper_adapter_map = d = {}
+
for ent in entities:
for entity in ent.entities:
if entity not in d:
d[entity] = (mapper, adapter, selectable, is_aliased_class, with_polymorphic)
ent.setup_entity(entity, *d[entity])
-
+
def __mapper_loads_polymorphically_with(self, mapper, adapter):
for m2 in mapper._with_polymorphic_mappers:
for m in m2.iterate_to_root():
return self.filter(sql.and_(*clauses))
- def min(self, col):
- """Execute the SQL ``min()`` function against the given column."""
-
- return self._col_aggregate(col, sql.func.min)
-
- def max(self, col):
- """Execute the SQL ``max()`` function against the given column."""
-
- return self._col_aggregate(col, sql.func.max)
-
- def sum(self, col):
- """Execute the SQL ``sum()`` function against the given column."""
-
- return self._col_aggregate(col, sql.func.sum)
-
- def avg(self, col):
- """Execute the SQL ``avg()`` function against the given column."""
-
- return self._col_aggregate(col, sql.func.avg)
-
def order_by(self, *criterion):
"""apply one or more ORDER BY criterion to the query and return the newly resulting ``Query``"""
_should_nest_selectable = property(_should_nest_selectable)
def count(self):
- """Apply this query's criterion to a SELECT COUNT statement.
-
- this is the purely generative version which will become
- the public method in version 0.5.
-
- """
- return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=list(self._mapper_zero().primary_key))
+ """Apply this query's criterion to a SELECT COUNT statement."""
+
+ return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=list(self._only_mapper_zero().primary_key))
def _col_aggregate(self, col, func, nested_cols=None):
- whereclause = self._criterion
-
context = QueryContext(self)
+
+ self._adjust_for_single_inheritance(context)
+
+ whereclause = context.whereclause
+
from_obj = self.__mapper_zero_from_obj()
if self._should_nest_selectable:
froms = [context.from_clause] # "load from a single FROM" mode, i.e. when select_from() or join() is used
else:
froms = context.froms # "load from discrete FROMs" mode, i.e. when each _MappedEntity has its own FROM
-
+
+ self._adjust_for_single_inheritance(context)
+
if eager_joins and self._should_nest_selectable:
# for eager joins present and LIMIT/OFFSET/DISTINCT, wrap the query inside a select,
# then append eager joins onto that
context.order_by = None
order_by_col_expr = []
- inner = sql.select(context.primary_columns + order_by_col_expr, context.whereclause, from_obj=froms, use_labels=labels, correlate=False, order_by=context.order_by, **self._select_args)
+ inner = sql.select(
+ context.primary_columns + order_by_col_expr,
+ context.whereclause,
+ from_obj=froms,
+ use_labels=labels,
+ correlate=False,
+ order_by=context.order_by,
+ **self._select_args
+ )
if self._correlate:
inner = inner.correlate(*self._correlate)
froms += context.eager_joins.values()
- statement = sql.select(context.primary_columns + context.secondary_columns, context.whereclause, from_obj=froms, use_labels=labels, for_update=for_update, correlate=False, order_by=context.order_by, **self._select_args)
+ statement = sql.select(
+ context.primary_columns + context.secondary_columns,
+ context.whereclause,
+ from_obj=froms,
+ use_labels=labels,
+ for_update=for_update,
+ correlate=False,
+ order_by=context.order_by,
+ **self._select_args
+ )
+
if self._correlate:
statement = statement.correlate(*self._correlate)
return context
+ def _adjust_for_single_inheritance(self, context):
+ """Apply single-table-inheritance filtering.
+
+ For all distinct single-table-inheritance mappers represented in the columns
+ clause of this query, add criterion to the WHERE clause of the given QueryContext
+ such that only the appropriate subtypes are selected from the total results.
+
+ """
+ for entity, (mapper, adapter, s, i, w) in self._mapper_adapter_map.iteritems():
+ if mapper.single and mapper.inherits and mapper.polymorphic_on and mapper.polymorphic_identity is not None:
+ crit = mapper.polymorphic_on.in_([m.polymorphic_identity for m in mapper.polymorphic_iterator()])
+ if adapter:
+ crit = adapter.traverse(crit)
+ crit = self._adapt_clause(crit, False, False)
+ context.whereclause = sql.and_(context.whereclause, crit)
+
def __log_debug(self, msg):
self.logger.debug(msg)
self.entities = [entity]
self.entity_zero = entity
self.entity_name = entity_name
-
+
def setup_entity(self, entity, mapper, adapter, from_obj, is_aliased_class, with_polymorphic):
self.mapper = mapper
self.extension = self.mapper.extension
return main, entname
def setup_context(self, query, context):
- # if single-table inheritance mapper, add "typecol IN (polymorphic)" criterion so
- # that we only load the appropriate types
- if self.mapper.single and self.mapper.inherits is not None and self.mapper.polymorphic_on is not None and self.mapper.polymorphic_identity is not None:
- context.whereclause = sql.and_(context.whereclause, self.mapper.polymorphic_on.in_([m.polymorphic_identity for m in self.mapper.polymorphic_iterator()]))
+ adapter = self._get_entity_clauses(query, context)
context.froms.append(self.selectable)
- adapter = self._get_entity_clauses(query, context)
-
if context.order_by is False and self.mapper.order_by:
context.order_by = self.mapper.order_by
from sqlalchemy.orm import util as mapperutil
-class ColumnLoader(LoaderStrategy):
- """Default column loader."""
+class DefaultColumnLoader(LoaderStrategy):
+ def _register_attribute(self, compare_function, copy_function, mutable_scalars, comparator_factory, callable_=None, proxy_property=None):
+ self.logger.info("%s register managed attribute" % self)
+
+ for mapper in self.parent.polymorphic_iterator():
+ if mapper is self.parent or not mapper.concrete:
+ sessionlib.register_attribute(
+ mapper.class_,
+ self.key,
+ uselist=False,
+ useobject=False,
+ copy_function=copy_function,
+ compare_function=compare_function,
+ mutable_scalars=mutable_scalars,
+ comparator=comparator_factory(self.parent_property, mapper),
+ parententity=mapper,
+ callable_=callable_,
+ proxy_property=proxy_property
+ )
+
+DefaultColumnLoader.logger = log.class_logger(DefaultColumnLoader)
+
+class ColumnLoader(DefaultColumnLoader):
def init(self):
- super(ColumnLoader, self).init()
self.columns = self.parent_property.columns
self._should_log_debug = log.is_debug_enabled(self.logger)
self.is_composite = hasattr(self.parent_property, 'composite_class')
def init_class_attribute(self):
self.is_class_level = True
- self.logger.info("%s register managed attribute" % self)
coltype = self.columns[0].type
- sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator, parententity=self.parent)
+
+ self._register_attribute(
+ coltype.compare_values,
+ coltype.copy_value,
+ self.columns[0].type.is_mutable(),
+ self.parent_property.comparator_factory
+ )
def create_row_processor(self, selectcontext, path, mapper, row, adapter):
key, col = self.key, self.columns[0]
return False
else:
return True
- sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=copy, compare_function=compare, mutable_scalars=True, comparator=self.parent_property.comparator, parententity=self.parent)
+
+ self._register_attribute(
+ compare,
+ copy,
+ True,
+ self.parent_property.comparator_factory
+ )
def create_row_processor(self, selectcontext, path, mapper, row, adapter):
key, columns, composite_class = self.key, self.columns, self.parent_property.composite_class
CompositeColumnLoader.logger = log.class_logger(CompositeColumnLoader)
-class DeferredColumnLoader(LoaderStrategy):
+class DeferredColumnLoader(DefaultColumnLoader):
"""Deferred column loader, a per-column or per-column-group lazy loader."""
def create_row_processor(self, selectcontext, path, mapper, row, adapter):
return (new_execute, None)
def init(self):
- super(DeferredColumnLoader, self).init()
if hasattr(self.parent_property, 'composite_class'):
raise NotImplementedError("Deferred loading for composite types not implemented yet")
self.columns = self.parent_property.columns
def init_class_attribute(self):
self.is_class_level = True
- self.logger.info("%s register managed attribute" % self)
- sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, callable_=self.class_level_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator, parententity=self.parent)
+ self._register_attribute(
+ self.columns[0].type.compare_values,
+ self.columns[0].type.copy_value,
+ self.columns[0].type.is_mutable(),
+ self.parent_property.comparator_factory,
+ callable_=self.class_level_loader,
+ )
def setup_query(self, context, entity, path, adapter, only_load_props=None, **kwargs):
if \
class AbstractRelationLoader(LoaderStrategy):
def init(self):
- super(AbstractRelationLoader, self).init()
for attr in ['mapper', 'target', 'table', 'uselist']:
setattr(self, attr, getattr(self.parent_property, attr))
self._should_log_debug = log.is_debug_enabled(self.logger)
else:
state.initialize(self.key)
- def _register_attribute(self, class_, callable_=None, **kwargs):
+ def _register_attribute(self, class_, callable_=None, impl_class=None, **kwargs):
self.logger.info("%s register managed %s attribute" % (self, (self.uselist and "collection" or "scalar")))
if self.parent_property.backref:
else:
attribute_ext = None
- sessionlib.register_attribute(class_, self.key, uselist=self.uselist, useobject=True, extension=attribute_ext, cascade=self.parent_property.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator, parententity=self.parent, **kwargs)
+ sessionlib.register_attribute(
+ class_,
+ self.key,
+ uselist=self.uselist,
+ useobject=True,
+ extension=attribute_ext,
+ cascade=self.parent_property.cascade,
+ trackparent=True,
+ typecallable=self.parent_property.collection_class,
+ callable_=callable_,
+ comparator=self.parent_property.comparator,
+ parententity=self.parent,
+ impl_class=impl_class,
+ **kwargs
+ )
class NoLoader(AbstractRelationLoader):
def init_class_attribute(self):
import testenv; testenv.configure_for_tests()
from testlib import testing, sa
-from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData
+from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData, func
from sqlalchemy.orm import mapper, relation, create_session
from testlib.testing import eq_
from testlib.compat import set
sess = create_session()
query = sess.query(Foo)
assert query.count() == 100
- assert query.filter(foo.c.bar<30).min(foo.c.bar) == 0
- assert query.filter(foo.c.bar<30).max(foo.c.bar) == 29
+ assert sess.query(func.min(foo.c.bar)).filter(foo.c.bar<30).one() == (0,)
+
+ assert sess.query(func.max(foo.c.bar)).filter(foo.c.bar<30).one() == (29,)
assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).next()[0] == 29
assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).next()[0] == 29
testing.db.dialect.dbapi.version_info[:4] == (1, 2, 1, 'gamma')):
return
- query = create_session().query(Foo)
- assert query.filter(foo.c.bar<30).sum(foo.c.bar) == 435
+ query = create_session().query(func.sum(foo.c.bar))
+ assert query.filter(foo.c.bar<30).one() == (435,)
@testing.fails_on('firebird', 'mssql')
@testing.resolve_artifact_names
def test_aggregate_2(self):
- query = create_session().query(Foo)
- avg = query.filter(foo.c.bar < 30).avg(foo.c.bar)
+ query = create_session().query(func.avg(foo.c.bar))
+ avg = query.filter(foo.c.bar < 30).one()[0]
eq_(round(avg, 1), 14.5)
@testing.resolve_artifact_names
self.assertEquals(sess.query(Person).filter(Person.person_id==subq).one(), e1)
-
+
def test_mixed_entities(self):
sess = create_session()
[(Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'),
u'Elbonia, Inc.')]
)
+
+
+ self.assertEquals(
+ sess.query(Manager.name).all(),
+ [('pointy haired boss', ), ('dogbert',)]
+ )
+
+ self.assertEquals(
+ sess.query(Manager.name + " foo").all(),
+ [('pointy haired boss foo', ), ('dogbert foo',)]
+ )
+
+
+ self.assertEquals(
+ sess.query(Engineer.name, Engineer.primary_language).all(),
+ [(u'dilbert', u'java'), (u'wally', u'c++'), (u'vlad', u'cobol')]
+ )
+
+ self.assertEquals(
+ sess.query(Boss.name, Boss.golf_swing).all(),
+ [(u'pointy haired boss', u'fore')]
+ )
+
+ # TODO: I think raise error on these for now. different inheritance/loading schemes have different
+ # results here, all incorrect
+ #
+ # self.assertEquals(
+ # sess.query(Person.name, Engineer.primary_language).all(),
+ # []
+ # )
+
+ # self.assertEquals(
+ # sess.query(Person.name, Engineer.primary_language, Manager.manager_name).all(),
+ # []
+ # )
self.assertEquals(
sess.query(Person.name, Company.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(),
from sqlalchemy.orm import *
from testlib import *
from testlib.fixtures import Base
+from orm._base import MappedTest, ComparableEntity
-class SingleInheritanceTest(ORMTest):
+class SingleInheritanceTest(MappedTest):
def define_tables(self, metadata):
global employees_table
employees_table = Table('employees', metadata,
Column('engineer_info', String(50)),
Column('type', String(20))
)
-
- def test_single_inheritance(self):
- class Employee(Base):
+
+ def setup_classes(self):
+ class Employee(ComparableEntity):
pass
class Manager(Employee):
pass
class JuniorEngineer(Engineer):
pass
+ @testing.resolve_artifact_names
+ def setup_mappers(self):
mapper(Employee, employees_table, polymorphic_on=employees_table.c.type)
mapper(Manager, inherits=Employee, polymorphic_identity='manager')
mapper(Engineer, inherits=Employee, polymorphic_identity='engineer')
mapper(JuniorEngineer, inherits=Engineer, polymorphic_identity='juniorengineer')
+
+ @testing.resolve_artifact_names
+ def test_single_inheritance(self):
session = create_session()
m1 = Manager(name='Tom', manager_data='knows how to manage things')
e1 = Engineer(name='Kurt', engineer_info='knows how to hack')
e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed')
- session.save(m1)
- session.save(e1)
- session.save(e2)
+ session.add_all([m1, e1, e2])
session.flush()
assert session.query(Employee).all() == [m1, e1, e2]
m1 = session.query(Manager).one()
session.expire(m1, ['manager_data'])
self.assertEquals(m1.manager_data, "knows how to manage things")
+
+ @testing.resolve_artifact_names
+ def test_multi_qualification(self):
+ session = create_session()
+
+ m1 = Manager(name='Tom', manager_data='knows how to manage things')
+ e1 = Engineer(name='Kurt', engineer_info='knows how to hack')
+ e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed')
+
+ session.add_all([m1, e1, e2])
+ session.flush()
+
+ ealias = aliased(Engineer)
+ self.assertEquals(
+ session.query(Manager, ealias).all(),
+ [(m1, e1), (m1, e2)]
+ )
+
+ self.assertEquals(
+ session.query(Manager.name).all(),
+ [("Tom",)]
+ )
+
+ self.assertEquals(
+ session.query(Manager.name, ealias.name).all(),
+ [("Tom", "Kurt"), ("Tom", "Ed")]
+ )
+
+ self.assertEquals(
+ session.query(func.upper(Manager.name), func.upper(ealias.name)).all(),
+ [("TOM", "KURT"), ("TOM", "ED")]
+ )
+
+ self.assertEquals(
+ session.query(Manager).add_entity(ealias).all(),
+ [(m1, e1), (m1, e2)]
+ )
+
+ self.assertEquals(
+ session.query(Manager.name).add_column(ealias.name).all(),
+ [("Tom", "Kurt"), ("Tom", "Ed")]
+ )
+
+ # TODO: I think raise error on this for now
+ # self.assertEquals(
+ # session.query(Employee.name, Manager.manager_data, Engineer.engineer_info).all(),
+ # []
+ # )
+
+ @testing.resolve_artifact_names
+ def test_select_from(self):
+ sess = create_session()
+ m1 = Manager(name='Tom', manager_data='data1')
+ m2 = Manager(name='Tom2', manager_data='data2')
+ e1 = Engineer(name='Kurt', engineer_info='knows how to hack')
+ e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed')
+ sess.add_all([m1, m2, e1, e2])
+ sess.flush()
+
+ self.assertEquals(
+ sess.query(Manager).select_from(employees_table.select().limit(10)).all(),
+ [m1, m2]
+ )
+
+ @testing.resolve_artifact_names
+ def test_count(self):
+ sess = create_session()
+ m1 = Manager(name='Tom', manager_data='data1')
+ m2 = Manager(name='Tom2', manager_data='data2')
+ e1 = Engineer(name='Kurt', engineer_info='data3')
+ e2 = JuniorEngineer(name='marvin', engineer_info='data4')
+ sess.add_all([m1, m2, e1, e2])
+ sess.flush()
+
+ self.assertEquals(sess.query(Manager).count(), 2)
+ self.assertEquals(sess.query(Engineer).count(), 2)
+ self.assertEquals(sess.query(Employee).count(), 4)
+
+ self.assertEquals(sess.query(Manager).filter(Manager.name.like('%m%')).count(), 2)
+ self.assertEquals(sess.query(Employee).filter(Employee.name.like('%m%')).count(), 3)
class SingleOnJoinedTest(ORMTest):
def define_tables(self, metadata):
def test_sum(self):
sess = create_session()
orders = sess.query(Order).filter(Order.id.in_([2, 3, 4]))
- assert orders.sum(Order.user_id * Order.address_id) == 79
+ self.assertEquals(orders.values(func.sum(Order.user_id * Order.address_id)).next(), (79,))
def test_apply(self):
sess = create_session()