- Added a more aggressive check for "uncompiled mappers",
helps particularly with declarative layer [ticket:995]
+ - The methodology behind "primaryjoin"/"secondaryjoin" has
+ been refactored. Behavior should be slightly more
+ intelligent, primarily in terms of error messages which
+ have been pared down to be more readable. In a slight
+ number of scenarios it can better resolve the correct
+ foreign key than before.
+
- Added comparable_property(), adds query Comparator behavior
to regular, unmanaged Python properties
"""
from sqlalchemy.orm import sync
-from sqlalchemy.orm.sync import ONETOMANY,MANYTOONE,MANYTOMANY
from sqlalchemy import sql, util, exceptions
+from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY
def create_dependency_processor(prop):
self.passive_updates = prop.passive_updates
self.enable_typechecks = prop.enable_typechecks
self.key = prop.key
-
- self._compile_synchronizers()
+ if not self.prop.synchronize_pairs:
+ raise exceptions.ArgumentError("Can't build a DependencyProcessor for relation %s. No target attributes to populate between parent and child are present" % self.prop)
def _get_instrumented_attribute(self):
"""Return the ``InstrumentedAttribute`` handled by this
raise NotImplementedError()
- def _compile_synchronizers(self):
- """Assemble a list of *synchronization rules*.
-
- These are fired to populate attributes from one side
- of a relation to another.
- """
-
- self.syncrules = sync.ClauseSynchronizer(self.parent, self.mapper, self.direction)
- if self.direction == sync.MANYTOMANY:
- self.syncrules.compile(self.prop.primaryjoin, issecondary=False, foreign_keys=self.foreign_keys)
- self.syncrules.compile(self.prop.secondaryjoin, issecondary=True, foreign_keys=self.foreign_keys)
- else:
- self.syncrules.compile(self.prop.primaryjoin, foreign_keys=self.foreign_keys)
-
def _conditional_post_update(self, state, uowcommit, related):
"""Execute a post_update call.
if state is not None and self.post_update:
for x in related:
if x is not None:
- uowcommit.register_object(state, postupdate=True, post_update_cols=self.syncrules.dest_columns())
+ uowcommit.register_object(state, postupdate=True, post_update_cols=[r for l, r in self.prop.synchronize_pairs])
break
def _pks_changed(self, uowcommit, state):
- return self.syncrules.source_changes(uowcommit, state)
+ raise NotImplementedError()
def __str__(self):
return "%s(%s)" % (self.__class__.__name__, str(self.prop))
if dest is None or (not self.post_update and uowcommit.is_deleted(dest)):
return
self._verify_canload(child)
- self.syncrules.execute(source, dest, source, child, clearkeys)
+ if clearkeys:
+ sync.clear(dest, self.mapper, self.prop.synchronize_pairs)
+ else:
+ sync.populate(source, self.parent, dest, self.mapper, self.prop.synchronize_pairs)
+
+ def _pks_changed(self, uowcommit, state):
+ return sync.source_changes(uowcommit, state, self.parent, self.prop.synchronize_pairs)
class DetectKeySwitch(DependencyProcessor):
"""a special DP that works for many-to-one relations, fires off for
elem.dict[self.key]._state in switchers
]:
uowcommit.register_object(s, listonly=self.passive_updates)
- self.syncrules.execute(s.dict[self.key]._state, s, None, None, False)
+ sync.populate(s.dict[self.key]._state, self.mapper, s, self.parent, self.prop.synchronize_pairs)
+ #self.syncrules.execute(s.dict[self.key]._state, s, None, None, False)
+
+ def _pks_changed(self, uowcommit, state):
+ return sync.source_changes(uowcommit, state, self.mapper, self.prop.synchronize_pairs)
class ManyToOneDP(DependencyProcessor):
def __init__(self, prop):
def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
- source = child
- dest = state
- if dest is None or (not self.post_update and uowcommit.is_deleted(dest)):
+ if state is None or (not self.post_update and uowcommit.is_deleted(state)):
return
- self._verify_canload(child)
- self.syncrules.execute(source, dest, dest, child, clearkeys)
+
+ if clearkeys or child is None:
+ sync.clear(state, self.parent, self.prop.synchronize_pairs)
+ else:
+ self._verify_canload(child)
+ sync.populate(child, self.mapper, state, self.parent, self.prop.synchronize_pairs)
class ManyToManyDP(DependencyProcessor):
def register_dependencies(self, uowcommit):
if not self.passive_updates and unchanged and self._pks_changed(uowcommit, state):
for child in unchanged:
associationrow = {}
- self.syncrules.update(associationrow, state, child, "old_")
+ sync.update(state, self.parent, associationrow, "old_", self.prop.synchronize_pairs)
+ sync.update(child, self.mapper, associationrow, "old_", self.prop.secondary_synchronize_pairs)
+
+ #self.syncrules.update(associationrow, state, child, "old_")
secondary_update.append(associationrow)
if secondary_delete:
if associationrow is None:
return
self._verify_canload(child)
- self.syncrules.execute(None, associationrow, state, child, clearkeys)
+
+ sync.populate_dict(state, self.parent, associationrow, self.prop.synchronize_pairs)
+ sync.populate_dict(child, self.mapper, associationrow, self.prop.secondary_synchronize_pairs)
+
+ def _pks_changed(self, uowcommit, state):
+ return sync.source_changes(uowcommit, state, self.parent, self.prop.synchronize_pairs)
class AssociationDP(OneToManyDP):
def __init__(self, *args, **kwargs):
EXT_CONTINUE = EXT_PASS = util.symbol('EXT_CONTINUE')
EXT_STOP = util.symbol('EXT_STOP')
+ONETOMANY = util.symbol('ONETOMANY')
+MANYTOONE = util.symbol('MANYTOONE')
+MANYTOMANY = util.symbol('MANYTOMANY')
+
class MapperExtension(object):
"""Base implementation for customizing Mapper behavior.
self._dependency_processors = []
self._clause_adapter = None
self._requires_row_aliasing = False
-
+ self.__inherits_equated_pairs = None
+
if not issubclass(class_, object):
raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__)
self.__should_log_info = logging.is_info_enabled(self.logger)
self.__should_log_debug = logging.is_debug_enabled(self.logger)
- self._compile_class()
- self._compile_inheritance()
- self._compile_extensions()
- self._compile_properties()
- self._compile_pks()
+ self.__compile_class()
+ self.__compile_inheritance()
+ self.__compile_extensions()
+ self.__compile_properties()
+ self.__compile_pks()
global __new_mappers
__new_mappers = True
self.__log("constructed")
to execute once all mappers have been constructed.
"""
- self.__log("_initialize_properties() started")
+ self.__log("__initialize_properties() started")
l = [(key, prop) for key, prop in self.__props.iteritems()]
for key, prop in l:
self.__log("initialize prop " + key)
if getattr(prop, 'key', None) is None:
prop.init(key, self)
- self.__log("_initialize_properties() complete")
+ self.__log("__initialize_properties() complete")
self.__props_init = True
- def _compile_extensions(self):
+ def __compile_extensions(self):
"""Go through the global_extensions list as well as the list
of ``MapperExtensions`` specified for this ``Mapper`` and
creates a linked list of those extensions.
for ext in extlist:
self.extension.append(ext)
- def _compile_inheritance(self):
+ def __compile_inheritance(self):
"""Configure settings related to inherting and/or inherited mappers being present."""
if self.inherits:
self.single = True
if not self.local_table is self.inherits.local_table:
if self.concrete:
- self._synchronizer = None
self.mapped_table = self.local_table
for mapper in self.iterate_to_root():
if mapper.polymorphic_on:
# stuff we dont want (allows test/inheritance.InheritTest4 to pass)
self.inherit_condition = sql.join(self.inherits.local_table, self.local_table).onclause
self.mapped_table = sql.join(self.inherits.mapped_table, self.local_table, self.inherit_condition)
- # generate sync rules. similarly to creating the on clause, specify a
- # stricter set of tables to create "sync rules" by,based on the immediate
- # inherited table, rather than all inherited tables
- self._synchronizer = sync.ClauseSynchronizer(self, self, sync.ONETOMANY)
- if self.inherit_foreign_keys:
- fks = util.Set(self.inherit_foreign_keys)
- else:
- fks = None
- self._synchronizer.compile(self.mapped_table.onclause, foreign_keys=fks)
+
+ fks = util.to_set(self.inherit_foreign_keys)
+ self.__inherits_equated_pairs = sqlutil.criterion_as_pairs(self.mapped_table.onclause, consider_as_foreign_keys=fks)
else:
- self._synchronizer = None
self.mapped_table = self.local_table
if self.polymorphic_identity is not None:
self.inherits.polymorphic_map[self.polymorphic_identity] = self
else:
self._all_tables = util.Set()
self.base_mapper = self
- self._synchronizer = None
self.mapped_table = self.local_table
if self.polymorphic_identity:
if self.polymorphic_on is None:
if self.mapped_table is None:
raise exceptions.ArgumentError("Mapper '%s' does not have a mapped_table specified. (Are you using the return value of table.create()? It no longer has a return value.)" % str(self))
- def _compile_pks(self):
+ def __compile_pks(self):
self.tables = sqlutil.find_tables(self.mapped_table)
return getattr(getattr(cls, clskey), key)
- def _compile_properties(self):
+ def __compile_properties(self):
# object attribute names mapped to MapperProperty objects
self.__props = util.OrderedDict()
for mapper in self._inheriting_mappers:
mapper._adapt_inherited_property(key, prop)
- def _compile_class(self):
+ def __compile_class(self):
"""If this mapper is to be a primary mapper (i.e. the
non_primary flag is not set), associate this Mapper with the
given class_ and entity name.
# TODO: this fires off more than needed, try to organize syncrules
# per table
for m in util.reversed(list(mapper.iterate_to_root())):
- if m._synchronizer:
- m._synchronizer.execute(state, state)
+ if m.__inherits_equated_pairs:
+ m._synchronize_inherited(state)
# testlib.pragma exempt:__hash__
inserted_objects.add((state, connection))
if 'after_update' in mapper.extension.methods:
mapper.extension.after_update(mapper, connection, state.obj())
+ def _synchronize_inherited(self, state):
+ sync.populate(state, self, state, self, self.__inherits_equated_pairs)
+
def _postfetch(self, uowtransaction, connection, table, state, resultproxy, params, value_params):
"""After an ``INSERT`` or ``UPDATE``, assemble newly generated
values on an instance. For columns which are marked as being generated
"""
from sqlalchemy import sql, schema, util, exceptions, logging
-from sqlalchemy.sql.util import ClauseAdapter
+from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, find_columns
from sqlalchemy.sql import visitors, operators, ColumnElement
from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency, object_mapper
from sqlalchemy.orm import session as sessionlib
from sqlalchemy.orm.mapper import _class_to_mapper
from sqlalchemy.orm.util import CascadeOptions, PropertyAliasedClauses
-from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty
+from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty, ONETOMANY, MANYTOONE, MANYTOMANY
from sqlalchemy.exceptions import ArgumentError
-
__all__ = ('ColumnProperty', 'CompositeProperty', 'SynonymProperty',
'ComparableProperty', 'PropertyLoader', 'BackRef')
def __eq__(self, other):
if other is None:
- if self.prop.direction == sync.ONETOMANY:
+ if self.prop.direction == ONETOMANY:
return ~sql.exists([1], self.prop.primaryjoin)
else:
return self.prop._optimized_compare(None)
def __ne__(self, other):
if other is None:
- if self.prop.direction == sync.MANYTOONE:
+ if self.prop.direction == MANYTOONE:
return sql.or_(*[x!=None for x in self.prop.foreign_keys])
elif self.prop.uselist:
return self.any()
return self.argument.class_
def do_init(self):
- self._determine_targets()
- self._determine_joins()
- self._determine_fks()
- self._determine_direction()
- self._determine_remote_side()
+ self.__determine_targets()
+ self.__determine_joins()
+ self.__determine_fks()
+ self.__determine_direction()
+ self.__determine_remote_side()
self._post_init()
- def _determine_targets(self):
+ def __determine_targets(self):
if isinstance(self.argument, type):
self.mapper = mapper.class_mapper(self.argument, entity_name=self.entity_name, compile=False)
elif isinstance(self.argument, mapper.Mapper):
if self.cascade.delete_orphan:
if self.parent.class_ is self.mapper.class_:
- raise exceptions.ArgumentError("In relationship '%s', can't establish 'delete-orphan' cascade rule on a self-referential relationship. You probably want cascade='all', which includes delete cascading but not orphan detection." %(str(self)))
+ raise exceptions.ArgumentError("In relationship '%s', can't establish 'delete-orphan' cascade "
+ "rule on a self-referential relationship. "
+ "You probably want cascade='all', which includes delete cascading but not orphan detection." %(str(self)))
self.mapper.primary_mapper().delete_orphans.append((self.key, self.parent.class_))
- def _determine_joins(self):
+ def __determine_joins(self):
if self.secondaryjoin is not None and self.secondary is None:
raise exceptions.ArgumentError("Property '" + self.key + "' specified with secondary join condition but no secondary argument")
# if join conditions were not specified, figure them out based on foreign keys
if self.primaryjoin is None:
self.primaryjoin = _search_for_join(self.parent, self.target).onclause
except exceptions.ArgumentError, e:
- raise exceptions.ArgumentError("""Error determining primary and/or secondary join for relationship '%s'. If the underlying error cannot be corrected, you should specify the 'primaryjoin' (and 'secondaryjoin', if there is an association table present) keyword arguments to the relation() function (or for backrefs, by specifying the backref using the backref() function with keyword arguments) to explicitly specify the join conditions. Nested error is \"%s\"""" % (str(self), str(e)))
+ raise exceptions.ArgumentError("Could not determine join condition between parent/child tables on relation %s. "
+ "Specify a 'primaryjoin' expression. If this is a many-to-many relation, 'secondaryjoin' is needed as well." % (self))
- def _col_is_part_of_mappings(self, column):
+ def __col_is_part_of_mappings(self, column):
if self.secondary is None:
return self.parent.mapped_table.c.contains_column(column) or \
self.target.c.contains_column(column)
self.target.c.contains_column(column) or \
self.secondary.c.contains_column(column) is not None
- def _determine_fks(self):
+ def __determine_fks(self):
if self._legacy_foreignkey and not self._refers_to_parent_table():
self.foreign_keys = self._legacy_foreignkey
- self._opposite_side = util.Set()
+ arg_foreign_keys = self.foreign_keys
+
+ eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=self.viewonly)
+ eq_pairs = [(l, r) for l, r in eq_pairs if self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)]
+
+ if not eq_pairs:
+ if not self.viewonly and criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=True):
+ raise exceptions.ArgumentError("Could not locate any equated column pairs for primaryjoin condition '%s' on relation %s. "
+ "If no equated pairs exist, the relation must be marked as viewonly=True." % (self.primaryjoin, self)
+ )
+ else:
+ raise exceptions.ArgumentError("Could not determine relation direction for primaryjoin condition '%s', on relation %s. "
+ "Specify the foreign_keys argument to indicate which columns on the relation are foreign." % (self.primaryjoin, self))
+
+ self.foreign_keys = util.OrderedSet([r for l, r in eq_pairs])
+ self._opposite_side = util.OrderedSet([l for l, r in eq_pairs])
+ self.synchronize_pairs = eq_pairs
+
+ if self.secondaryjoin:
+ sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=arg_foreign_keys)
+ sq_pairs = [(l, r) for l, r in sq_pairs if self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)]
+
+ if not sq_pairs:
+ if not self.viewonly and criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=True):
+ raise exceptions.ArgumentError("Could not locate any equated column pairs for secondaryjoin condition '%s' on relation %s. "
+ "If no equated pairs exist, the relation must be marked as viewonly=True." % (self.secondaryjoin, self)
+ )
+ else:
+ raise exceptions.ArgumentError("Could not determine relation direction for secondaryjoin condition '%s', on relation %s. "
+ "Specify the foreign_keys argument to indicate which columns on the relation are foreign." % (self.secondaryjoin, self))
- if self.foreign_keys:
- def visit_binary(binary):
- if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
- return
- if binary.left in self.foreign_keys:
- self._opposite_side.add(binary.right)
- if binary.right in self.foreign_keys:
- self._opposite_side.add(binary.left)
+ self.foreign_keys.update([r for l, r in sq_pairs])
+ self._opposite_side.update([l for l, r in sq_pairs])
+ self.secondary_synchronize_pairs = sq_pairs
else:
- self.foreign_keys = util.Set()
- def visit_binary(binary):
- if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
- return
-
- # this check is for when the user put the "view_only" flag on and has tables that have nothing
- # to do with the relationship's parent/child mappings in the join conditions. we dont want cols
- # or clauses related to those external tables dealt with. see orm.relationships.ViewOnlyTest
- if not self._col_is_part_of_mappings(binary.left) or not self._col_is_part_of_mappings(binary.right):
- return
-
- for f in binary.left.foreign_keys:
- if f.references(binary.right.table):
- self.foreign_keys.add(binary.left)
- self._opposite_side.add(binary.right)
- for f in binary.right.foreign_keys:
- if f.references(binary.left.table):
- self.foreign_keys.add(binary.right)
- self._opposite_side.add(binary.left)
-
- visitors.traverse(self.primaryjoin, visit_binary=visit_binary)
-
- if not self.foreign_keys:
- raise exceptions.ArgumentError(
- "Can't locate any foreign key columns in primary join "
- "condition '%s' for relationship '%s'. Specify "
- "'foreign_keys' argument to indicate which columns in "
- "the join condition are foreign." %(str(self.primaryjoin), str(self)))
-
- if self.secondaryjoin is not None:
- visitors.traverse(self.secondaryjoin, visit_binary=visit_binary)
+ self.secondary_synchronize_pairs = None
+
+ def equated_pairs(self):
+ return zip(self.local_side, self.remote_side)
+ equated_pairs = property(equated_pairs)
+
+ def __determine_remote_side(self):
+ if self.remote_side:
+ if self.direction is MANYTOONE:
+ eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_referenced_keys=self.remote_side, any_operator=True)
+ else:
+ eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=self.remote_side, any_operator=True)
+ if self.secondaryjoin:
+ sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=self.foreign_keys, any_operator=True)
+ sq_pairs = [(l, r) for l, r in sq_pairs if self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)]
+ eq_pairs += sq_pairs
+ else:
+ eq_pairs = zip(self._opposite_side, self.foreign_keys)
- def _determine_direction(self):
+ if self.direction is MANYTOONE:
+ self.remote_side, self.local_side = [util.OrderedSet(s) for s in zip(*eq_pairs)]
+ else:
+ self.local_side, self.remote_side = [util.OrderedSet(s) for s in zip(*eq_pairs)]
+
+ def __determine_direction(self):
"""Determine our *direction*, i.e. do we represent one to
many, many to many, etc.
"""
if self.secondaryjoin is not None:
- self.direction = sync.MANYTOMANY
+ self.direction = MANYTOMANY
elif self._refers_to_parent_table():
# for a self referential mapper, if the "foreignkey" is a single or composite primary key,
# then we are "many to one", since the remote site of the relationship identifies a singular entity.
if self._legacy_foreignkey:
for f in self._legacy_foreignkey:
if not f.primary_key:
- self.direction = sync.ONETOMANY
+ self.direction = ONETOMANY
else:
- self.direction = sync.MANYTOONE
+ self.direction = MANYTOONE
elif self.remote_side:
for f in self.foreign_keys:
if f in self.remote_side:
- self.direction = sync.ONETOMANY
+ self.direction = ONETOMANY
return
else:
- self.direction = sync.MANYTOONE
+ self.direction = MANYTOONE
else:
- self.direction = sync.ONETOMANY
+ self.direction = ONETOMANY
else:
for mappedtable, parenttable in [(self.mapper.mapped_table, self.parent.mapped_table), (self.mapper.local_table, self.parent.local_table)]:
onetomany = [c for c in self.foreign_keys if mappedtable.c.contains_column(c)]
elif onetomany and manytoone:
continue
elif onetomany:
- self.direction = sync.ONETOMANY
+ self.direction = ONETOMANY
break
elif manytoone:
- self.direction = sync.MANYTOONE
+ self.direction = MANYTOONE
break
else:
raise exceptions.ArgumentError(
"the child's mapped tables. Specify 'foreign_keys' "
"argument." % (str(self)))
- def _determine_remote_side(self):
- if not self.remote_side:
- if self.direction is sync.MANYTOONE:
- self.remote_side = util.Set(self._opposite_side)
- elif self.direction is sync.ONETOMANY or self.direction is sync.MANYTOMANY:
- self.remote_side = util.Set(self.foreign_keys)
-
- self.local_side = util.Set(self._opposite_side).union(util.Set(self.foreign_keys)).difference(self.remote_side)
-
def _post_init(self):
if logging.is_info_enabled(self.logger):
self.logger.info(str(self) + " setup primary join " + str(self.primaryjoin))
self.logger.info(str(self) + " setup secondary join " + str(self.secondaryjoin))
- self.logger.info(str(self) + " foreign keys " + str([str(c) for c in self.foreign_keys]))
- self.logger.info(str(self) + " remote columns " + str([str(c) for c in self.remote_side]))
- self.logger.info(str(self) + " relation direction " + (self.direction is sync.ONETOMANY and "one-to-many" or (self.direction is sync.MANYTOONE and "many-to-one" or "many-to-many")))
+ self.logger.info(str(self) + " synchronize pairs " + ",".join(["(%s => %s)" % (l, r) for l, r in self.synchronize_pairs]))
+ self.logger.info(str(self) + " equated pairs " + ",".join(["(%s == %s)" % (l, r) for l, r in self.equated_pairs]))
+ self.logger.info(str(self) + " relation direction " + (self.direction is ONETOMANY and "one-to-many" or (self.direction is MANYTOONE and "many-to-one" or "many-to-many")))
- if self.uselist is None and self.direction is sync.MANYTOONE:
+ if self.uselist is None and self.direction is MANYTOONE:
self.uselist = False
if self.uselist is None:
primaryjoin = self.primaryjoin
if fromselectable is not frommapper.local_table:
- if self.direction is sync.ONETOMANY:
+ if self.direction is ONETOMANY:
primaryjoin = ClauseAdapter(fromselectable, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
- elif self.direction is sync.MANYTOONE:
+ elif self.direction is MANYTOONE:
primaryjoin = ClauseAdapter(fromselectable, include=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
elif self.secondaryjoin:
primaryjoin = ClauseAdapter(fromselectable, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
class LazyLoader(AbstractRelationLoader):
def init(self):
super(LazyLoader, self).init()
- (self.lazywhere, self.lazybinds, self.equated_columns) = self._create_lazy_clause(self.parent_property)
+ (self.__lazywhere, self.__bind_to_col, self._equated_columns) = self.__create_lazy_clause(self.parent_property)
- self.logger.info(str(self.parent_property) + " lazy loading clause " + str(self.lazywhere))
+ self.logger.info(str(self.parent_property) + " lazy loading clause " + str(self.__lazywhere))
# determine if our "lazywhere" clause is the same as the mapper's
# get() clause. then we can just use mapper.get()
#from sqlalchemy.orm import query
- self.use_get = not self.uselist and self.mapper._get_clause[0].compare(self.lazywhere)
+ self.use_get = not self.uselist and self.mapper._get_clause[0].compare(self.__lazywhere)
if self.use_get:
self.logger.info(str(self.parent_property) + " will use query.get() to optimize instance loads")
return self._lazy_none_clause(reverse_direction)
if not reverse_direction:
- (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.equated_columns)
+ (criterion, bind_to_col, rev) = (self.__lazywhere, self.__bind_to_col, self._equated_columns)
else:
- (criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
- bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
+ (criterion, bind_to_col, rev) = LazyLoader.__create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
def visit_bindparam(bindparam):
mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent
def _lazy_none_clause(self, reverse_direction=False):
if not reverse_direction:
- (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.equated_columns)
+ (criterion, bind_to_col, rev) = (self.__lazywhere, self.__bind_to_col, self._equated_columns)
else:
- (criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
- bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
+ (criterion, bind_to_col, rev) = LazyLoader.__create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
def visit_binary(binary):
mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent
instance._state.reset(self.key)
return (new_execute, None, None)
- def _create_lazy_clause(cls, prop, reverse_direction=False):
- (primaryjoin, secondaryjoin, remote_side) = (prop.primaryjoin, prop.secondaryjoin, prop.remote_side)
-
+ def __create_lazy_clause(cls, prop, reverse_direction=False):
binds = {}
equated_columns = {}
+ secondaryjoin = prop.secondaryjoin
+ equated = dict(prop.equated_pairs)
+
def should_bind(targetcol, othercol):
- if not prop._col_is_part_of_mappings(targetcol):
- return False
-
if reverse_direction and not secondaryjoin:
- return targetcol in remote_side
+ return othercol in equated
else:
- return othercol in remote_side
+ return targetcol in equated
def visit_binary(binary):
- if not isinstance(binary.left, sql.ColumnElement) or not isinstance(binary.right, sql.ColumnElement):
- return
leftcol = binary.left
rightcol = binary.right
equated_columns[leftcol] = rightcol
if should_bind(leftcol, rightcol):
- if leftcol in binds:
- binary.left = binds[leftcol]
- else:
- binary.left = binds[leftcol] = sql.bindparam(None, None, type_=binary.right.type)
+ if leftcol not in binds:
+ binds[leftcol] = sql.bindparam(None, None, type_=binary.right.type)
+ binary.left = binds[leftcol]
+ elif should_bind(rightcol, leftcol):
+ if rightcol not in binds:
+ binds[rightcol] = sql.bindparam(None, None, type_=binary.left.type)
+ binary.right = binds[rightcol]
- # the "left is not right" compare is to handle part of a join clause that is "table.c.col1==table.c.col1",
- # which can happen in rare cases (test/orm/relationships.py RelationTest2)
- if leftcol is not rightcol and should_bind(rightcol, leftcol):
- if rightcol in binds:
- binary.right = binds[rightcol]
- else:
- binary.right = binds[rightcol] = sql.bindparam(None, None, type_=binary.left.type)
-
-
- lazywhere = primaryjoin
+ lazywhere = prop.primaryjoin
- if not secondaryjoin or not reverse_direction:
+ if not prop.secondaryjoin or not reverse_direction:
lazywhere = visitors.traverse(lazywhere, clone=True, visit_binary=visit_binary)
- if secondaryjoin is not None:
+ if prop.secondaryjoin is not None:
if reverse_direction:
secondaryjoin = visitors.traverse(secondaryjoin, clone=True, visit_binary=visit_binary)
lazywhere = sql.and_(lazywhere, secondaryjoin)
- return (lazywhere, binds, equated_columns)
- _create_lazy_clause = classmethod(_create_lazy_clause)
+
+ bind_to_col = dict([(binds[col].key, col) for col in binds])
+
+ return (lazywhere, bind_to_col, equated_columns)
+ __create_lazy_clause = classmethod(__create_lazy_clause)
LazyLoader.logger = logging.class_logger(LazyLoader)
ident = []
allnulls = True
for primary_key in prop.mapper.primary_key:
- val = instance_mapper._get_committed_attr_by_column(instance, strategy.equated_columns[primary_key])
+ val = instance_mapper._get_committed_attr_by_column(instance, strategy._equated_columns[primary_key])
allnulls = allnulls and val is None
ident.append(val)
if allnulls:
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""Contains the ClauseSynchronizer class, which is used to map
-attributes between two objects in a manner corresponding to a SQL
-clause that compares column values.
+"""private module containing functions used for copying data between instances
+based on join conditions.
"""
from sqlalchemy import schema, exceptions, util
-from sqlalchemy.sql import visitors, operators
+from sqlalchemy.sql import visitors, operators, util as sqlutil
from sqlalchemy import logging
from sqlalchemy.orm import util as mapperutil
+from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY # legacy
-ONETOMANY = 0
-MANYTOONE = 1
-MANYTOMANY = 2
-
-class ClauseSynchronizer(object):
- """Given a SQL clause, usually a series of one or more binary
- expressions between columns, and a set of 'source' and
- 'destination' mappers, compiles a set of SyncRules corresponding
- to that information.
-
- The ClauseSynchronizer can then be executed given a set of
- parent/child objects or destination dictionary, which will iterate
- through each of its SyncRules and execute them. Each SyncRule
- will copy the value of a single attribute from the parent to the
- child, corresponding to the pair of columns in a particular binary
- expression, using the source and destination mappers to map those
- two columns to object attributes within parent and child.
- """
-
- def __init__(self, parent_mapper, child_mapper, direction):
- self.parent_mapper = parent_mapper
- self.child_mapper = child_mapper
- self.direction = direction
- self.syncrules = []
-
- def compile(self, sqlclause, foreign_keys=None, issecondary=None):
- def compile_binary(binary):
- """Assemble a SyncRule given a single binary condition."""
-
- if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
- return
-
- source_column = None
- dest_column = None
-
- if foreign_keys is None:
- if binary.left.table == binary.right.table:
- raise exceptions.ArgumentError("need foreign_keys argument for self-referential sync")
-
- if binary.left in util.Set([f.column for f in binary.right.foreign_keys]):
- dest_column = binary.right
- source_column = binary.left
- elif binary.right in util.Set([f.column for f in binary.left.foreign_keys]):
- dest_column = binary.left
- source_column = binary.right
- else:
- if binary.left in foreign_keys:
- source_column = binary.right
- dest_column = binary.left
- elif binary.right in foreign_keys:
- source_column = binary.left
- dest_column = binary.right
+def populate(source, source_mapper, dest, dest_mapper, synchronize_pairs):
+ for l, r in synchronize_pairs:
+ try:
+ value = source_mapper._get_state_attr_by_column(source, l)
+ except exceptions.UnmappedColumnError:
+ _raise_col_to_prop(False, source_mapper, l, dest_mapper, r)
- if source_column and dest_column:
- if self.direction == ONETOMANY:
- self.syncrules.append(SyncRule(self.parent_mapper, source_column, dest_column, dest_mapper=self.child_mapper))
- elif self.direction == MANYTOONE:
- self.syncrules.append(SyncRule(self.child_mapper, source_column, dest_column, dest_mapper=self.parent_mapper))
- else:
- if not issecondary:
- self.syncrules.append(SyncRule(self.parent_mapper, source_column, dest_column, dest_mapper=self.child_mapper, issecondary=issecondary))
- else:
- self.syncrules.append(SyncRule(self.child_mapper, source_column, dest_column, dest_mapper=self.parent_mapper, issecondary=issecondary))
+ try:
+ dest_mapper._set_state_attr_by_column(dest, r, value)
+ except exceptions.UnmappedColumnError:
+ self._raise_col_to_prop(True, source_mapper, l, dest_mapper, r)
- rules_added = len(self.syncrules)
- visitors.traverse(sqlclause, visit_binary=compile_binary)
- if len(self.syncrules) == rules_added:
- raise exceptions.ArgumentError("No syncrules generated for join criterion " + str(sqlclause))
+def clear(dest, dest_mapper, synchronize_pairs):
+ for l, r in synchronize_pairs:
+ if r.primary_key:
+ raise exceptions.AssertionError("Dependency rule tried to blank-out primary key column '%s' on instance '%s'" % (r, mapperutil.state_str(dest)))
+ try:
+ dest_mapper._set_state_attr_by_column(dest, r, None)
+ except exceptions.UnmappedColumnError:
+ _raise_col_to_prop(True, None, l, dest_mapper, r)
- def dest_columns(self):
- return [r.dest_column for r in self.syncrules if r.dest_column is not None]
+def update(source, source_mapper, dest, old_prefix, synchronize_pairs):
+ for l, r in synchronize_pairs:
+ try:
+ oldvalue = source_mapper._get_committed_attr_by_column(source.obj(), l)
+ value = source_mapper._get_state_attr_by_column(source, l)
+ except exceptions.UnmappedColumnError:
+ self._raise_col_to_prop(False, source_mapper, l, None, r)
+ dest[r.key] = value
+ dest[old_prefix + r.key] = oldvalue
- def update(self, dest, parent, child, old_prefix):
- for rule in self.syncrules:
- rule.update(dest, parent, child, old_prefix)
-
- def execute(self, source, dest, obj=None, child=None, clearkeys=None):
- for rule in self.syncrules:
- rule.execute(source, dest, obj, child, clearkeys)
-
- def source_changes(self, uowcommit, source):
- for rule in self.syncrules:
- if rule.source_changes(uowcommit, source):
- return True
- else:
- return False
+def populate_dict(source, source_mapper, dict_, synchronize_pairs):
+ for l, r in synchronize_pairs:
+ try:
+ value = source_mapper._get_state_attr_by_column(source, l)
+ except exceptions.UnmappedColumnError:
+ _raise_col_to_prop(False, source_mapper, l, None, r)
-class SyncRule(object):
- """An instruction indicating how to populate the objects on each
- side of a relationship.
-
- E.g. if table1 column A is joined against table2 column
- B, and we are a one-to-many from table1 to table2, a syncrule
- would say *take the A attribute from object1 and assign it to the
- B attribute on object2*.
- """
+ dict_[r.key] = value
- def __init__(self, source_mapper, source_column, dest_column, dest_mapper=None, issecondary=None):
- self.source_mapper = source_mapper
- self.source_column = source_column
- self.issecondary = issecondary
- self.dest_mapper = dest_mapper
- self.dest_column = dest_column
-
- #print "SyncRule", source_mapper, source_column, dest_column, dest_mapper
-
- def dest_primary_key(self):
- # late-evaluating boolean since some syncs are created
- # before the mapper has assembled pks
- try:
- return self._dest_primary_key
- except AttributeError:
- self._dest_primary_key = self.dest_mapper is not None and self.dest_column in self.dest_mapper._pks_by_table[self.dest_column.table] and not self.dest_mapper.allow_null_pks
- return self._dest_primary_key
-
- def _raise_col_to_prop(self, isdest):
- if isdest:
- raise exceptions.UnmappedColumnError("Can't execute sync rule for destination column '%s'; mapper '%s' does not map this column. Try using an explicit `foreign_keys` collection which does not include this column (or use a viewonly=True relation)." % (self.dest_column, self.dest_mapper))
- else:
- raise exceptions.UnmappedColumnError("Can't execute sync rule for source column '%s'; mapper '%s' does not map this column. Try using an explicit `foreign_keys` collection which does not include destination column '%s' (or use a viewonly=True relation)." % (self.source_column, self.source_mapper, self.dest_column))
-
- def source_changes(self, uowcommit, source):
+def source_changes(uowcommit, source, source_mapper, synchronize_pairs):
+ for l, r in synchronize_pairs:
try:
- prop = self.source_mapper._get_col_to_prop(self.source_column)
+ prop = source_mapper._get_col_to_prop(l)
except exceptions.UnmappedColumnError:
- self._raise_col_to_prop(False)
+ _raise_col_to_prop(False, source_mapper, l, None, r)
(added, unchanged, deleted) = uowcommit.get_attribute_history(source, prop.key, passive=True)
- return bool(added and deleted)
-
- def update(self, dest, parent, child, old_prefix):
- if self.issecondary is False:
- source = parent
- elif self.issecondary is True:
- source = child
+ if added and deleted:
+ return True
+ else:
+ return False
+
+def dest_changes(uowcommit, dest, dest_mapper, synchronize_pairs):
+ for l, r in synchronize_pairs:
try:
- oldvalue = self.source_mapper._get_committed_attr_by_column(source.obj(), self.source_column)
- value = self.source_mapper._get_state_attr_by_column(source, self.source_column)
+ prop = dest_mapper._get_col_to_prop(r)
except exceptions.UnmappedColumnError:
- self._raise_col_to_prop(False)
- dest[self.dest_column.key] = value
- dest[old_prefix + self.dest_column.key] = oldvalue
+ _raise_col_to_prop(True, None, l, dest_mapper, r)
+ (added, unchanged, deleted) = uowcommit.get_attribute_history(dest, prop.key, passive=True)
+ if added and deleted:
+ return True
+ else:
+ return False
+
+def _raise_col_to_prop(isdest, source_mapper, source_column, dest_mapper, dest_column):
+ if isdest:
+ raise exceptions.UnmappedColumnError("Can't execute sync rule for destination column '%s'; mapper '%s' does not map this column. Try using an explicit `foreign_keys` collection which does not include this column (or use a viewonly=True relation)." % (dest_column, source_mapper))
+ else:
+ raise exceptions.UnmappedColumnError("Can't execute sync rule for source column '%s'; mapper '%s' does not map this column. Try using an explicit `foreign_keys` collection which does not include destination column '%s' (or use a viewonly=True relation)." % (source_column, source_mapper, dest_column))
- def execute(self, source, dest, parent, child, clearkeys):
- # TODO: break the "dictionary" case into a separate method like 'update' above,
- # reduce conditionals
- if source is None:
- if self.issecondary is False:
- source = parent
- elif self.issecondary is True:
- source = child
- if clearkeys or source is None:
- value = None
- clearkeys = True
- else:
- try:
- value = self.source_mapper._get_state_attr_by_column(source, self.source_column)
- except exceptions.UnmappedColumnError:
- self._raise_col_to_prop(False)
- if isinstance(dest, dict):
- dest[self.dest_column.key] = value
- else:
- if clearkeys and self.dest_primary_key():
- raise exceptions.AssertionError("Dependency rule tried to blank-out primary key column '%s' on instance '%s'" % (str(self.dest_column), mapperutil.state_str(dest)))
-
- if logging.is_debug_enabled(self.logger):
- self.logger.debug("execute() instances: %s(%s)->%s(%s) ('%s')" % (mapperutil.state_str(source), str(self.source_column), mapperutil.state_str(dest), str(self.dest_column), value))
- try:
- self.dest_mapper._set_state_attr_by_column(dest, self.dest_column, value)
- except exceptions.UnmappedColumnError:
- self._raise_col_to_prop(True)
-
-SyncRule.logger = logging.class_logger(SyncRule)
-
def references(self, column):
"""Return True if this references the given column via a foreign key."""
for fk in self.foreign_keys:
- if fk.column is column:
+ if fk.references(column.table):
return True
else:
return False
-from sqlalchemy import exceptions, schema, topological, util
+from sqlalchemy import exceptions, schema, topological, util, sql
from sqlalchemy.sql import expression, operators, visitors
from itertools import chain
"""Utility functions that build upon SQL and Schema constructs."""
def sort_tables(tables, reverse=False):
+ """sort a collection of Table objects in order of their foreign-key dependency."""
+
tuples = []
class TVisitor(schema.SchemaVisitor):
def visit_foreign_key(_self, fkey):
return sequence
def find_tables(clause, check_columns=False, include_aliases=False):
+ """locate Table objects within the given expression."""
+
tables = []
kwargs = {}
if include_aliases:
return tables
def find_columns(clause):
+ """locate Column objects within the given expression."""
+
cols = util.Set()
def visit_column(col):
cols.add(col)
return expression.ColumnSet(columns.difference(omit))
+def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_referenced_keys=None, any_operator=False):
+ """traverse an expression and locate binary criterion pairs."""
+
+ if consider_as_foreign_keys and consider_as_referenced_keys:
+ raise exceptions.ArgumentError("Can only specify one of 'consider_as_foreign_keys' or 'consider_as_referenced_keys'")
+
+ def visit_binary(binary):
+ if not any_operator and binary.operator != operators.eq:
+ return
+ if not isinstance(binary.left, sql.ColumnElement) or not isinstance(binary.right, sql.ColumnElement):
+ return
+
+ if consider_as_foreign_keys:
+ if binary.left in consider_as_foreign_keys:
+ pairs.append((binary.right, binary.left))
+ elif binary.right in consider_as_foreign_keys:
+ pairs.append((binary.left, binary.right))
+ elif consider_as_referenced_keys:
+ if binary.left in consider_as_referenced_keys:
+ pairs.append((binary.left, binary.right))
+ elif binary.right in consider_as_referenced_keys:
+ pairs.append((binary.right, binary.left))
+ else:
+ if isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column):
+ if binary.left.references(binary.right):
+ pairs.append((binary.right, binary.left))
+ elif binary.right.references(binary.left):
+ pairs.append((binary.left, binary.right))
+ pairs = []
+ visitors.traverse(expression, visit_binary=visit_binary)
+ return pairs
+
class AliasedRow(object):
def __init__(self, row, map):
return self.row.keys()
def row_adapter(from_, equivalent_columns=None):
- """create a row adapter against a selectable."""
+ """create a row adapter callable against a selectable."""
if equivalent_columns is None:
equivalent_columns = {}
collection_class = lambda: Ordered2(lambda v: (v.a, v.b))
self._test_composite_mapped(collection_class)
+# TODO: are these tests redundant vs. the above tests ?
+# remove if so
+class CustomCollectionsTest(ORMTest):
+ def define_tables(self, metadata):
+ global sometable, someothertable
+ sometable = Table('sometable', metadata,
+ Column('col1',Integer, primary_key=True),
+ Column('data', String(30)))
+ someothertable = Table('someothertable', metadata,
+ Column('col1', Integer, primary_key=True),
+ Column('scol1', Integer, ForeignKey(sometable.c.col1)),
+ Column('data', String(20))
+ )
+ def test_basic(self):
+ class MyList(list):
+ pass
+ class Foo(object):
+ pass
+ class Bar(object):
+ pass
+ mapper(Foo, sometable, properties={
+ 'bars':relation(Bar, collection_class=MyList)
+ })
+ mapper(Bar, someothertable)
+ f = Foo()
+ assert isinstance(f.bars, MyList)
+
+ def test_lazyload(self):
+ """test that a 'set' can be used as a collection and can lazyload."""
+ class Foo(object):
+ pass
+ class Bar(object):
+ pass
+ mapper(Foo, sometable, properties={
+ 'bars':relation(Bar, collection_class=set)
+ })
+ mapper(Bar, someothertable)
+ f = Foo()
+ f.bars.add(Bar())
+ f.bars.add(Bar())
+ sess = create_session()
+ sess.save(f)
+ sess.flush()
+ sess.clear()
+ f = sess.query(Foo).get(f.col1)
+ assert len(list(f.bars)) == 2
+ f.bars.clear()
+
+ def test_dict(self):
+ """test that a 'dict' can be used as a collection and can lazyload."""
+
+ class Foo(object):
+ pass
+ class Bar(object):
+ pass
+ class AppenderDict(dict):
+ @collection.appender
+ def set(self, item):
+ self[id(item)] = item
+ @collection.remover
+ def remove(self, item):
+ if id(item) in self:
+ del self[id(item)]
+
+ mapper(Foo, sometable, properties={
+ 'bars':relation(Bar, collection_class=AppenderDict)
+ })
+ mapper(Bar, someothertable)
+ f = Foo()
+ f.bars.set(Bar())
+ f.bars.set(Bar())
+ sess = create_session()
+ sess.save(f)
+ sess.flush()
+ sess.clear()
+ f = sess.query(Foo).get(f.col1)
+ assert len(list(f.bars)) == 2
+ f.bars.clear()
+
+ def test_dict_wrapper(self):
+ """test that the supplied 'dict' wrapper can be used as a collection and can lazyload."""
+
+ class Foo(object):
+ pass
+ class Bar(object):
+ def __init__(self, data): self.data = data
+
+ mapper(Foo, sometable, properties={
+ 'bars':relation(Bar,
+ collection_class=collections.column_mapped_collection(someothertable.c.data))
+ })
+ mapper(Bar, someothertable)
+
+ f = Foo()
+ col = collections.collection_adapter(f.bars)
+ col.append_with_event(Bar('a'))
+ col.append_with_event(Bar('b'))
+ sess = create_session()
+ sess.save(f)
+ sess.flush()
+ sess.clear()
+ f = sess.query(Foo).get(f.col1)
+ assert len(list(f.bars)) == 2
+
+ existing = set([id(b) for b in f.bars.values()])
+
+ col = collections.collection_adapter(f.bars)
+ col.append_with_event(Bar('b'))
+ f.bars['a'] = Bar('a')
+ sess.flush()
+ sess.clear()
+ f = sess.query(Foo).get(f.col1)
+ assert len(list(f.bars)) == 2
+
+ replaced = set([id(b) for b in f.bars.values()])
+ self.assert_(existing != replaced)
+
+ def test_list(self):
+ class Parent(object):
+ pass
+ class Child(object):
+ pass
+
+ mapper(Parent, sometable, properties={
+ 'children':relation(Child, collection_class=list)
+ })
+ mapper(Child, someothertable)
+
+ control = list()
+ p = Parent()
+
+ o = Child()
+ control.append(o)
+ p.children.append(o)
+ assert control == p.children
+ assert control == list(p.children)
+
+ o = [Child(), Child(), Child(), Child()]
+ control.extend(o)
+ p.children.extend(o)
+ assert control == p.children
+ assert control == list(p.children)
+
+ assert control[0] == p.children[0]
+ assert control[-1] == p.children[-1]
+ assert control[1:3] == p.children[1:3]
+
+ del control[1]
+ del p.children[1]
+ assert control == p.children
+ assert control == list(p.children)
+
+ o = [Child()]
+ control[1:3] = o
+ p.children[1:3] = o
+ assert control == p.children
+ assert control == list(p.children)
+
+ o = [Child(), Child(), Child(), Child()]
+ control[1:3] = o
+ p.children[1:3] = o
+ assert control == p.children
+ assert control == list(p.children)
+
+ o = [Child(), Child(), Child(), Child()]
+ control[-1:-2] = o
+ p.children[-1:-2] = o
+ assert control == p.children
+ assert control == list(p.children)
+
+ o = [Child(), Child(), Child(), Child()]
+ control[4:] = o
+ p.children[4:] = o
+ assert control == p.children
+ assert control == list(p.children)
+
+ o = Child()
+ control.insert(0, o)
+ p.children.insert(0, o)
+ assert control == p.children
+ assert control == list(p.children)
+
+ o = Child()
+ control.insert(3, o)
+ p.children.insert(3, o)
+ assert control == p.children
+ assert control == list(p.children)
+
+ o = Child()
+ control.insert(999, o)
+ p.children.insert(999, o)
+ assert control == p.children
+ assert control == list(p.children)
+
+ del control[0:1]
+ del p.children[0:1]
+ assert control == p.children
+ assert control == list(p.children)
+
+ del control[1:1]
+ del p.children[1:1]
+ assert control == p.children
+ assert control == list(p.children)
+
+ del control[1:3]
+ del p.children[1:3]
+ assert control == p.children
+ assert control == list(p.children)
+
+ del control[7:]
+ del p.children[7:]
+ assert control == p.children
+ assert control == list(p.children)
+
+ assert control.pop() == p.children.pop()
+ assert control == p.children
+ assert control == list(p.children)
+
+ assert control.pop(0) == p.children.pop(0)
+ assert control == p.children
+ assert control == list(p.children)
+
+ assert control.pop(2) == p.children.pop(2)
+ assert control == p.children
+ assert control == list(p.children)
+
+ o = Child()
+ control.insert(2, o)
+ p.children.insert(2, o)
+ assert control == p.children
+ assert control == list(p.children)
+
+ control.remove(o)
+ p.children.remove(o)
+ assert control == p.children
+ assert control == list(p.children)
+
+ def test_custom(self):
+ class Parent(object):
+ pass
+ class Child(object):
+ pass
+
+ class MyCollection(object):
+ def __init__(self):
+ self.data = []
+ @collection.appender
+ def append(self, value):
+ self.data.append(value)
+ @collection.remover
+ def remove(self, value):
+ self.data.remove(value)
+ @collection.iterator
+ def __iter__(self):
+ return iter(self.data)
+
+ mapper(Parent, sometable, properties={
+ 'children':relation(Child, collection_class=MyCollection)
+ })
+ mapper(Child, someothertable)
+
+ control = list()
+ p1 = Parent()
+
+ o = Child()
+ control.append(o)
+ p1.children.append(o)
+ assert control == list(p1.children)
+
+ o = Child()
+ control.append(o)
+ p1.children.append(o)
+ assert control == list(p1.children)
+
+ o = Child()
+ control.append(o)
+ p1.children.append(o)
+ assert control == list(p1.children)
+
+ sess = create_session()
+ sess.save(p1)
+ sess.flush()
+ sess.clear()
+
+ p2 = sess.query(Parent).get(p1.col1)
+ o = list(p2.children)
+ assert len(o) == 3
+
if __name__ == "__main__":
testenv.main()
pass
class Manager(Person):
pass
-
- mapper(Person, people, properties={
- 'manager':relation(Manager, primaryjoin=people.c.manager_id==managers.c.person_id, uselist=False)
- })
- mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id)
-
- self.assertRaisesMessage(exceptions.ArgumentError,
- r"Can't determine relation direction for relationship 'Person\.manager \(Manager\)' - foreign key columns are present in both the parent and the child's mapped tables\. Specify 'foreign_keys' argument\.",
- compile_mappers
- )
- clear_mappers()
-
+
+ # note that up until recently (0.4.4), we had to specify "foreign_keys" here
+ # for this primary join.
mapper(Person, people, properties={
'manager':relation(Manager, primaryjoin=(people.c.manager_id ==
managers.c.person_id),
- foreign_keys=[people.c.manager_id],
uselist=False, post_update=True)
})
mapper(Manager, managers, inherits=Person,
inherit_condition=people.c.person_id==managers.c.person_id)
-
+
+ self.assertEquals(class_mapper(Person).get_property('manager').foreign_keys, set([people.c.manager_id]))
+
session = create_session()
p = Person(name='some person')
m = Manager(name='some manager')
from sqlalchemy.orm import collections
from sqlalchemy.orm.collections import collection
from testlib import *
+from testlib import fixtures
class RelationTest(TestBase):
"""An extended topological sort test
assert t3.count().scalar() == 1
-# TODO: move these tests to either attributes.py test or its own module
-class CustomCollectionsTest(ORMTest):
- def define_tables(self, metadata):
- global sometable, someothertable
- sometable = Table('sometable', metadata,
- Column('col1',Integer, primary_key=True),
- Column('data', String(30)))
- someothertable = Table('someothertable', metadata,
- Column('col1', Integer, primary_key=True),
- Column('scol1', Integer, ForeignKey(sometable.c.col1)),
- Column('data', String(20))
- )
- def testbasic(self):
- class MyList(list):
- pass
- class Foo(object):
- pass
- class Bar(object):
- pass
- mapper(Foo, sometable, properties={
- 'bars':relation(Bar, collection_class=MyList)
- })
- mapper(Bar, someothertable)
- f = Foo()
- assert isinstance(f.bars, MyList)
- def testlazyload(self):
- """test that a 'set' can be used as a collection and can lazyload."""
- class Foo(object):
- pass
- class Bar(object):
- pass
- mapper(Foo, sometable, properties={
- 'bars':relation(Bar, collection_class=set)
- })
- mapper(Bar, someothertable)
- f = Foo()
- f.bars.add(Bar())
- f.bars.add(Bar())
- sess = create_session()
- sess.save(f)
- sess.flush()
- sess.clear()
- f = sess.query(Foo).get(f.col1)
- assert len(list(f.bars)) == 2
- f.bars.clear()
-
- def testdict(self):
- """test that a 'dict' can be used as a collection and can lazyload."""
-
- class Foo(object):
- pass
- class Bar(object):
- pass
- class AppenderDict(dict):
- @collection.appender
- def set(self, item):
- self[id(item)] = item
- @collection.remover
- def remove(self, item):
- if id(item) in self:
- del self[id(item)]
-
- mapper(Foo, sometable, properties={
- 'bars':relation(Bar, collection_class=AppenderDict)
- })
- mapper(Bar, someothertable)
- f = Foo()
- f.bars.set(Bar())
- f.bars.set(Bar())
- sess = create_session()
- sess.save(f)
- sess.flush()
- sess.clear()
- f = sess.query(Foo).get(f.col1)
- assert len(list(f.bars)) == 2
- f.bars.clear()
-
- def testdictwrapper(self):
- """test that the supplied 'dict' wrapper can be used as a collection and can lazyload."""
-
- class Foo(object):
- pass
- class Bar(object):
- def __init__(self, data): self.data = data
-
- mapper(Foo, sometable, properties={
- 'bars':relation(Bar,
- collection_class=collections.column_mapped_collection(someothertable.c.data))
- })
- mapper(Bar, someothertable)
-
- f = Foo()
- col = collections.collection_adapter(f.bars)
- col.append_with_event(Bar('a'))
- col.append_with_event(Bar('b'))
- sess = create_session()
- sess.save(f)
- sess.flush()
- sess.clear()
- f = sess.query(Foo).get(f.col1)
- assert len(list(f.bars)) == 2
-
- existing = set([id(b) for b in f.bars.values()])
-
- col = collections.collection_adapter(f.bars)
- col.append_with_event(Bar('b'))
- f.bars['a'] = Bar('a')
- sess.flush()
- sess.clear()
- f = sess.query(Foo).get(f.col1)
- assert len(list(f.bars)) == 2
-
- replaced = set([id(b) for b in f.bars.values()])
- self.assert_(existing != replaced)
-
- def testlist(self):
- class Parent(object):
- pass
- class Child(object):
- pass
-
- mapper(Parent, sometable, properties={
- 'children':relation(Child, collection_class=list)
- })
- mapper(Child, someothertable)
-
- control = list()
- p = Parent()
-
- o = Child()
- control.append(o)
- p.children.append(o)
- assert control == p.children
- assert control == list(p.children)
-
- o = [Child(), Child(), Child(), Child()]
- control.extend(o)
- p.children.extend(o)
- assert control == p.children
- assert control == list(p.children)
-
- assert control[0] == p.children[0]
- assert control[-1] == p.children[-1]
- assert control[1:3] == p.children[1:3]
-
- del control[1]
- del p.children[1]
- assert control == p.children
- assert control == list(p.children)
-
- o = [Child()]
- control[1:3] = o
- p.children[1:3] = o
- assert control == p.children
- assert control == list(p.children)
-
- o = [Child(), Child(), Child(), Child()]
- control[1:3] = o
- p.children[1:3] = o
- assert control == p.children
- assert control == list(p.children)
-
- o = [Child(), Child(), Child(), Child()]
- control[-1:-2] = o
- p.children[-1:-2] = o
- assert control == p.children
- assert control == list(p.children)
-
- o = [Child(), Child(), Child(), Child()]
- control[4:] = o
- p.children[4:] = o
- assert control == p.children
- assert control == list(p.children)
-
- o = Child()
- control.insert(0, o)
- p.children.insert(0, o)
- assert control == p.children
- assert control == list(p.children)
-
- o = Child()
- control.insert(3, o)
- p.children.insert(3, o)
- assert control == p.children
- assert control == list(p.children)
-
- o = Child()
- control.insert(999, o)
- p.children.insert(999, o)
- assert control == p.children
- assert control == list(p.children)
-
- del control[0:1]
- del p.children[0:1]
- assert control == p.children
- assert control == list(p.children)
-
- del control[1:1]
- del p.children[1:1]
- assert control == p.children
- assert control == list(p.children)
-
- del control[1:3]
- del p.children[1:3]
- assert control == p.children
- assert control == list(p.children)
-
- del control[7:]
- del p.children[7:]
- assert control == p.children
- assert control == list(p.children)
-
- assert control.pop() == p.children.pop()
- assert control == p.children
- assert control == list(p.children)
-
- assert control.pop(0) == p.children.pop(0)
- assert control == p.children
- assert control == list(p.children)
-
- assert control.pop(2) == p.children.pop(2)
- assert control == p.children
- assert control == list(p.children)
-
- o = Child()
- control.insert(2, o)
- p.children.insert(2, o)
- assert control == p.children
- assert control == list(p.children)
-
- control.remove(o)
- p.children.remove(o)
- assert control == p.children
- assert control == list(p.children)
-
- def testobj(self):
- class Parent(object):
- pass
- class Child(object):
- pass
-
- class MyCollection(object):
- def __init__(self):
- self.data = []
- @collection.appender
- def append(self, value):
- self.data.append(value)
- @collection.remover
- def remove(self, value):
- self.data.remove(value)
- @collection.iterator
- def __iter__(self):
- return iter(self.data)
-
- mapper(Parent, sometable, properties={
- 'children':relation(Child, collection_class=MyCollection)
- })
- mapper(Child, someothertable)
-
- control = list()
- p1 = Parent()
-
- o = Child()
- control.append(o)
- p1.children.append(o)
- assert control == list(p1.children)
-
- o = Child()
- control.append(o)
- p1.children.append(o)
- assert control == list(p1.children)
-
- o = Child()
- control.append(o)
- p1.children.append(o)
- assert control == list(p1.children)
-
- sess = create_session()
- sess.save(p1)
- sess.flush()
- sess.clear()
-
- p2 = sess.query(Parent).get(p1.col1)
- o = list(p2.children)
- assert len(o) == 3
+
+
class ViewOnlyTest(ORMTest):
"""test a view_only mapping where a third table is pulled into the primary join condition,
assert set([x.t2id for x in c1.t2s]) == set([c2a.t2id, c2b.t2id])
assert set([x.t2id for x in c1.t2_view]) == set([c2b.t2id])
+class ViewOnlyTest3(ORMTest):
+ def define_tables(self, metadata):
+ global foos, bars
+ foos = Table('foos', metadata, Column('id', Integer, primary_key=True))
+ bars = Table('bars', metadata, Column('id', Integer, primary_key=True), Column('fid', Integer))
+
+ def test_viewonly_join(self):
+ class Foo(fixtures.Base):
+ pass
+ class Bar(fixtures.Base):
+ pass
+
+ mapper(Foo, foos, properties={
+ 'bars':relation(Bar, primaryjoin=foos.c.id>bars.c.fid, foreign_keys=[bars.c.fid], viewonly=True)
+ })
+
+ mapper(Bar, bars)
+
+ sess = create_session()
+ sess.save(Foo(id=4))
+ sess.save(Foo(id=9))
+ sess.save(Bar(id=1, fid=2))
+ sess.save(Bar(id=2, fid=3))
+ sess.save(Bar(id=3, fid=6))
+ sess.save(Bar(id=4, fid=7))
+ sess.flush()
+
+ sess = create_session()
+ self.assertEquals(sess.query(Foo).filter_by(id=4).one(), Foo(id=4, bars=[Bar(fid=2), Bar(fid=3)]))
+ self.assertEquals(sess.query(Foo).filter_by(id=9).one(), Foo(id=9, bars=[Bar(fid=2), Bar(fid=3), Bar(fid=6), Bar(fid=7)]))
+
+class InvalidRelationEscalationTest(ORMTest):
+ def define_tables(self, metadata):
+ global foos, bars, Foo, Bar
+ foos = Table('foos', metadata, Column('id', Integer, primary_key=True), Column('fid', Integer))
+ bars = Table('bars', metadata, Column('id', Integer, primary_key=True), Column('fid', Integer))
+ class Foo(object):
+ pass
+ class Bar(object):
+ pass
+
+ def test_no_join(self):
+ mapper(Foo, foos, properties={
+ 'bars':relation(Bar)
+ })
+
+ mapper(Bar, bars)
+ self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+
+ def test_no_join_self_ref(self):
+ mapper(Foo, foos, properties={
+ 'foos':relation(Foo)
+ })
+
+ mapper(Bar, bars)
+ self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+
+ def test_no_equated(self):
+ mapper(Foo, foos, properties={
+ 'bars':relation(Bar, primaryjoin=foos.c.id>bars.c.fid)
+ })
+
+ mapper(Bar, bars)
+ self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+
+ def test_no_equated_fks(self):
+ mapper(Foo, foos, properties={
+ 'bars':relation(Bar, primaryjoin=foos.c.id>bars.c.fid, foreign_keys=bars.c.fid)
+ })
+
+ mapper(Bar, bars)
+ self.assertRaisesMessage(exceptions.ArgumentError, "Could not locate any equated column pairs for primaryjoin condition", compile_mappers)
+
+ def test_no_equated_self_ref(self):
+ mapper(Foo, foos, properties={
+ 'foos':relation(Foo, primaryjoin=foos.c.id>foos.c.fid)
+ })
+
+ mapper(Bar, bars)
+ self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+
+ def test_no_equated_self_ref(self):
+ mapper(Foo, foos, properties={
+ 'foos':relation(Foo, primaryjoin=foos.c.id>foos.c.fid, foreign_keys=[foos.c.fid])
+ })
+
+ mapper(Bar, bars)
+ self.assertRaisesMessage(exceptions.ArgumentError, "Could not locate any equated column pairs for primaryjoin condition", compile_mappers)
+
+ def test_no_equated_viewonly(self):
+ mapper(Foo, foos, properties={
+ 'bars':relation(Bar, primaryjoin=foos.c.id>bars.c.fid, viewonly=True)
+ })
+
+ mapper(Bar, bars)
+ self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+
+ def test_no_equated_self_ref_viewonly(self):
+ mapper(Foo, foos, properties={
+ 'foos':relation(Foo, primaryjoin=foos.c.id>foos.c.fid, viewonly=True)
+ })
+
+ mapper(Bar, bars)
+
+ self.assertRaisesMessage(exceptions.ArgumentError, "Specify the foreign_keys argument to indicate which columns on the relation are foreign.", compile_mappers)
+
+ def test_no_equated_self_ref_viewonly_fks(self):
+ mapper(Foo, foos, properties={
+ 'foos':relation(Foo, primaryjoin=foos.c.id>foos.c.fid, viewonly=True, foreign_keys=[foos.c.fid])
+ })
+ compile_mappers()
+ self.assertEquals(Foo.foos.property.equated_pairs, [(foos.c.id, foos.c.fid)])
+
+ def test_equated(self):
+ mapper(Foo, foos, properties={
+ 'bars':relation(Bar, primaryjoin=foos.c.id==bars.c.fid)
+ })
+ mapper(Bar, bars)
+ self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+
+ def test_equated_self_ref(self):
+ mapper(Foo, foos, properties={
+ 'foos':relation(Foo, primaryjoin=foos.c.id==foos.c.fid)
+ })
+
+ self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+
+ def test_equated_self_ref_wrong_fks(self):
+ mapper(Foo, foos, properties={
+ 'foos':relation(Foo, primaryjoin=foos.c.id==foos.c.fid, foreign_keys=[bars.c.id])
+ })
+
+ self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+
+class InvalidRelationEscalationTestM2M(ORMTest):
+ def define_tables(self, metadata):
+ global foos, bars, Foo, Bar, foobars
+ foos = Table('foos', metadata, Column('id', Integer, primary_key=True))
+ foobars = Table('foobars', metadata, Column('fid', Integer), Column('bid', Integer))
+ bars = Table('bars', metadata, Column('id', Integer, primary_key=True))
+ class Foo(object):
+ pass
+ class Bar(object):
+ pass
+
+ def test_no_join(self):
+ mapper(Foo, foos, properties={
+ 'bars':relation(Bar, secondary=foobars)
+ })
+
+ mapper(Bar, bars)
+ self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+
+ def test_no_secondaryjoin(self):
+ mapper(Foo, foos, properties={
+ 'bars':relation(Bar, secondary=foobars, primaryjoin=foos.c.id>foobars.c.fid)
+ })
+
+ mapper(Bar, bars)
+ self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+
+ def test_bad_primaryjoin(self):
+ mapper(Foo, foos, properties={
+ 'bars':relation(Bar, secondary=foobars, primaryjoin=foos.c.id>foobars.c.fid, secondaryjoin=foobars.c.bid<=bars.c.id)
+ })
+
+ mapper(Bar, bars)
+ self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+
+ def test_bad_secondaryjoin(self):
+ mapper(Foo, foos, properties={
+ 'bars':relation(Bar, secondary=foobars, primaryjoin=foos.c.id==foobars.c.fid, secondaryjoin=foobars.c.bid<=bars.c.id, foreign_keys=[foobars.c.fid])
+ })
+
+ mapper(Bar, bars)
+ self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for secondaryjoin condition", compile_mappers)
+
+ def test_no_equated_secondaryjoin(self):
+ mapper(Foo, foos, properties={
+ 'bars':relation(Bar, secondary=foobars, primaryjoin=foos.c.id==foobars.c.fid, secondaryjoin=foobars.c.bid<=bars.c.id, foreign_keys=[foobars.c.fid, foobars.c.bid])
+ })
+
+ mapper(Bar, bars)
+ self.assertRaisesMessage(exceptions.ArgumentError, "Could not locate any equated column pairs for secondaryjoin condition", compile_mappers)
+
if __name__ == "__main__":
testenv.main()