relationship where it takes effect for all inheriting mappers.
[ticket:883]
+ - several ORM attributes have been removed or made private:
+ mapper.get_attr_by_column(), mapper.set_attr_by_column(),
+ mapper.pks_by_table, mapper.cascade_callable(),
+ MapperProperty.cascade_callable(), mapper.canload()
+
- fixed endless loop issue when using lazy="dynamic" on both
sides of a bi-directional relationship [ticket:872]
if isinstance(mapping_spec, schema.Column):
def keyfunc(value):
m = object_mapper(value)
- return m.get_attr_by_column(value, mapping_spec)
+ return m._get_attr_by_column(value, mapping_spec)
else:
cols = []
for c in mapping_spec:
mapping_spec = tuple(cols)
def keyfunc(value):
m = object_mapper(value)
- return tuple([m.get_attr_by_column(value, c) for c in mapping_spec])
+ return tuple([m._get_attr_by_column(value, c) for c in mapping_spec])
return lambda: MappedCollection(keyfunc)
def attribute_mapped_collection(attr_name):
def _verify_canload(self, child):
if not self.enable_typechecks:
return
- if child is not None and not self.mapper.canload(child):
+ if child is not None and not self.mapper._canload(child):
raise exceptions.FlushError("Attempting to flush an item of type %s on collection '%s', which is handled by mapper '%s' and does not load items of that type. Did you mean to use a polymorphic mapper for this relationship ? Set 'enable_typechecks=False' on the relation() to disable this exception. Mismatched typeloading may cause bi-directional relationships (backrefs) to not function properly." % (child.__class__, self.prop, self.mapper))
def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit):
uowcommit.register_object(child, isdelete=False)
elif self.hasparent(child) is False:
uowcommit.register_object(child, isdelete=True)
- for c in self.mapper.cascade_iterator('delete', child):
+ for c, m in self.mapper.cascade_iterator('delete', child):
uowcommit.register_object(c, isdelete=True)
def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit):
for child in childlist.deleted_items() + childlist.unchanged_items():
if child is not None and self.hasparent(child) is False:
uowcommit.register_object(child, isdelete=True)
- for c in self.mapper.cascade_iterator('delete', child):
+ for c, m in self.mapper.cascade_iterator('delete', child):
uowcommit.register_object(c, isdelete=True)
else:
for obj in deplist:
for child in childlist.deleted_items():
if self.hasparent(child) is False:
uowcommit.register_object(child, isdelete=True)
- for c in self.mapper.cascade_iterator('delete', child):
+ for c, m in self.mapper.cascade_iterator('delete', child):
uowcommit.register_object(c, isdelete=True)
def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit):
for child in childlist.deleted_items():
if self.cascade.delete_orphan and self.hasparent(child) is False:
uowcommit.register_object(child, isdelete=True)
- for c in self.mapper.cascade_iterator('delete', child):
+ for c, m in self.mapper.cascade_iterator('delete', child):
uowcommit.register_object(c, isdelete=True)
def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit):
"""
raise NotImplementedError()
-
+
def cascade_iterator(self, type, object, recursive=None, halt_on=None):
- """return an iterator of objects which are child objects of the given object,
- as attached to the attribute corresponding to this MapperProperty."""
+ """Iterate through instances related to the given instance along
+ a particular 'cascade' path, starting with this MapperProperty.
- return []
-
- def cascade_callable(self, type, object, callable_, recursive=None, halt_on=None):
- """run the given callable across all objects which are child objects of
- the given object, as attached to the attribute corresponding to this MapperProperty."""
+ See PropertyLoader for the related instance implementation.
+ """
- return []
+ return iter([])
def get_criterion(self, query, key, value):
"""Return a ``WHERE`` clause suitable for this
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import weakref, warnings, operator
+import weakref, warnings
+from itertools import chain
from sqlalchemy import sql, util, exceptions, logging
-from sqlalchemy.sql import expression, visitors
+from sqlalchemy.sql import expression, visitors, operators
from sqlalchemy.sql import util as sqlutil
from sqlalchemy.orm import util as mapperutil
from sqlalchemy.orm.util import ExtensionCarrier, create_row_adapter
from sqlalchemy.orm import sync, attributes
from sqlalchemy.orm.interfaces import MapperProperty, EXT_CONTINUE, PropComparator
-deferred_load = None
__all__ = ['Mapper', 'class_mapper', 'object_mapper', 'mapper_registry']
# a list of MapperExtensions that will be installed in all mappers by default
global_extensions = []
-# a constant returned by get_attr_by_column to indicate
+# a constant returned by _get_attr_by_column to indicate
# this mapper is not handling an attribute for a particular
# column
NO_ATTRIBUTE = object()
self._compile_inheritance()
self._compile_tables()
self._compile_properties()
+ self._compile_pks()
self._compile_selectable()
self.__log("constructed")
self.polymorphic_map[key] = class_or_mapper
def _compile_tables(self):
- """After the inheritance relationships have been reconciled,
- set up some more table-based instance variables and determine
- the *primary key* columns for all tables represented by this
- ``Mapper``.
- """
-
# summary of the various Selectable units:
# mapped_table - the Selectable that represents a join of the underlying Tables to be saved (or just the Table)
# local_table - the Selectable that was passed to this Mapper's constructor, if any
if not self.tables:
raise exceptions.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table))
- # TODO: move the "figure pks" step down into compile_properties; after
- # all columns have been mapped, assemble PK columns and their
- # proxied parents into the pks_by_table collection, then get rid
- # of the _has_pks method
-
- # determine primary key columns
- self.pks_by_table = {}
+ def _compile_pks(self):
- # go through all of our represented tables
- # and assemble primary key columns
- for t in self.tables + [self.mapped_table]:
+ self._pks_by_table = {}
+ self._cols_by_table = {}
+
+ all_cols = util.Set(chain(*[c2 for c2 in [col.proxy_set for col in [c for c in self._columntoproperty]]]))
+ pk_cols = util.Set([c for c in all_cols if c.primary_key])
+
+ for t in util.Set(self.tables + [self.mapped_table]):
self._all_tables.add(t)
- if t not in self.pks_by_table:
- self.pks_by_table[t] = util.OrderedSet()
- self.pks_by_table[t].update(t.primary_key)
-
- if self.primary_key_argument is not None:
+ if t.primary_key and pk_cols.issuperset(t.primary_key):
+ self._pks_by_table[t] = util.Set(t.primary_key).intersection(pk_cols)
+ self._cols_by_table[t] = util.Set(t.c).intersection(all_cols)
+
+ if self.primary_key_argument:
for k in self.primary_key_argument:
- self.pks_by_table.setdefault(k.table, util.OrderedSet()).add(k)
+ self._pks_by_table.setdefault(k.table, util.Set()).add(k)
- if len(self.pks_by_table[self.mapped_table]) == 0:
+ if len(self._pks_by_table[self.mapped_table]) == 0:
raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name))
if self.inherits is not None and not self.concrete and not self.primary_key_argument:
primary_key = expression.ColumnSet()
- for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]):
+ for col in (self.primary_key_argument or self._pks_by_table[self.mapped_table]):
c = self.mapped_table.corresponding_column(col, raiseerr=False)
if c is None:
for cc in self._equivalent_columns[col]:
self.primary_key = primary_key
self.__log("Identified primary key columns: " + str(primary_key))
-
+
+ # create a "get clause" based on the primary key. this is used
+ # by query.get() and many-to-one lazyloads to load this item
+ # by primary key.
_get_clause = sql.and_()
_get_params = {}
for primary_key in self.primary_key:
result = {}
def visit_binary(binary):
- if binary.operator == operator.eq:
+ if binary.operator == operators.eq:
if binary.left in result:
result[binary.left].add(binary.right)
else:
return
recursive.add(col)
for fk in col.foreign_keys:
- result.setdefault(fk.column, util.Set()).add(equiv)
+ if fk.column not in result:
+ result[fk.column] = util.Set()
+ result[fk.column].add(equiv)
equivs(fk.column, recursive, col)
- for column in (self.primary_key_argument or self.pks_by_table[self.mapped_table]):
+ for column in (self.primary_key_argument or self._pks_by_table[self.mapped_table]):
for col in column.proxy_set:
if not col.foreign_keys:
- result.setdefault(col, util.Set()).add(col)
+ if col not in result:
+ result[col] = util.Set()
+ result[col].add(col)
else:
equivs(col, util.Set(), col)
return getattr(getattr(cls, clskey), key)
def _compile_properties(self):
- """Inspect the properties dictionary sent to the Mapper's
- constructor as well as the mapped_table, and create
- ``MapperProperty`` objects corresponding to each mapped column
- and relation.
- """
# object attribute names mapped to MapperProperty objects
self.__props = util.OrderedDict()
# TODO: the "property already exists" case is still not well defined here.
# assuming single-column, etc.
- if column in self.primary_key and prop.columns[-1] in self.primary_key:
- warnings.warn(RuntimeWarning("On mapper %s, primary key column '%s' is being combined with distinct primary key column '%s' in attribute '%s'. Use explicit properties to give each column its own mapped attribute name." % (str(self), str(column), str(prop.columns[-1]), key)))
-
if prop.parent is not self:
# existing ColumnProperty from an inheriting mapper.
# make a copy and append our column to it
instance.
"""
- return [self.get_attr_by_column(instance, column) for column in self.primary_key]
+ return [self._get_attr_by_column(instance, column) for column in self.primary_key]
- def canload(self, instance):
+ def _canload(self, instance):
"""return true if this mapper is capable of loading the given instance"""
if self.polymorphic_on is not None:
return isinstance(instance, self.class_)
else:
return instance.__class__ is self.class_
- def _getpropbycolumn(self, column, raiseerror=True):
+ def _get_attr_by_column(self, obj, column):
+ """Return an instance attribute using a Column as the key."""
try:
- return self._columntoproperty[column]
+ return self._columntoproperty[column].getattr(obj, column)
except KeyError:
- try:
- prop = self.__props[column.key]
- if not raiseerror:
- return None
+ prop = self.__props.get(column.key, None)
+ if prop:
raise exceptions.InvalidRequestError("Column '%s.%s' is not available, due to conflicting property '%s':%s" % (column.table.name, column.name, column.key, repr(prop)))
- except KeyError:
- if not raiseerror:
- return None
+ else:
raise exceptions.InvalidRequestError("No column %s.%s is configured on mapper %s..." % (column.table.name, column.name, str(self)))
-
- def get_attr_by_column(self, obj, column, raiseerror=True):
- """Return an instance attribute using a Column as the key."""
-
- prop = self._getpropbycolumn(column, raiseerror)
- if prop is None:
- return NO_ATTRIBUTE
- return prop.getattr(obj, column)
-
- def set_attr_by_column(self, obj, column, value):
+
+ def _set_attr_by_column(self, obj, column, value):
"""Set the value of an instance attribute using a Column as the key."""
self._columntoproperty[column].setattr(obj, value, column)
table_to_mapper = {}
for mapper in self.base_mapper.polymorphic_iterator():
for t in mapper.tables:
- table_to_mapper.setdefault(t, mapper)
+ table_to_mapper[t] = mapper
- for table in sqlutil.sort_tables(table_to_mapper.keys(), reverse=False):
+ for table in sqlutil.sort_tables(table_to_mapper.keys()):
# two lists to store parameters for each table/object pair located
insert = []
update = []
for obj, connection in tups:
mapper = object_mapper(obj)
- if table not in mapper.tables or not mapper._has_pks(table):
+ if table not in mapper._pks_by_table:
continue
- pks = mapper.pks_by_table[table]
+ pks = mapper._pks_by_table[table]
instance_key = mapper.identity_key_from_instance(obj)
if self.__should_log_debug:
hasdata = False
if isinsert:
- for col in table.columns:
+ for col in mapper._cols_by_table[table]:
if col is mapper.version_id_col:
params[col.key] = 1
elif col in pks:
- value = mapper.get_attr_by_column(obj, col)
+ value = mapper._get_attr_by_column(obj, col)
if value is not None:
params[col.key] = value
elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col):
if col.default is None or value is not None:
params[col.key] = value
else:
- value = mapper.get_attr_by_column(obj, col, False)
- if value is NO_ATTRIBUTE:
- continue
+ value = mapper._get_attr_by_column(obj, col)
if col.default is None or value is not None:
if isinstance(value, sql.ClauseElement):
value_params[col] = value
params[col.key] = value
insert.append((obj, params, mapper, connection, value_params))
else:
- for col in table.columns:
+ for col in mapper._cols_by_table[table]:
if col is mapper.version_id_col:
- params[col._label] = mapper.get_attr_by_column(obj, col)
+ params[col._label] = mapper._get_attr_by_column(obj, col)
params[col.key] = params[col._label] + 1
for prop in mapper._columntoproperty.values():
history = attributes.get_history(obj, prop.key, passive=True)
if history and history.added_items():
hasdata = True
elif col in pks:
- params[col._label] = mapper.get_attr_by_column(obj, col)
+ params[col._label] = mapper._get_attr_by_column(obj, col)
elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col):
pass
else:
if post_update_cols is not None and col not in post_update_cols:
continue
- prop = mapper._getpropbycolumn(col, False)
- if prop is None:
- continue
+ prop = mapper._columntoproperty[col]
history = attributes.get_history(obj, prop.key, passive=True)
if history:
a = history.added_items()
if update:
mapper = table_to_mapper[table]
clause = sql.and_()
- for col in mapper.pks_by_table[table]:
+ for col in mapper._pks_by_table[table]:
clause.clauses.append(col == sql.bindparam(col._label, type_=col.type, unique=True))
if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col):
clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type_=col.type, unique=True))
statement = table.update(clause)
rows = 0
supports_sane_rowcount = True
- pks = mapper.pks_by_table[table]
+ pks = mapper._pks_by_table[table]
def comparator(a, b):
for col in pks:
x = cmp(a[1][col._label],b[1][col._label])
if primary_key is not None:
i = 0
- for col in mapper.pks_by_table[table]:
- if mapper.get_attr_by_column(obj, col) is None and len(primary_key) > i:
- mapper.set_attr_by_column(obj, col, primary_key[i])
+ for col in mapper._pks_by_table[table]:
+ if mapper._get_attr_by_column(obj, col) is None and len(primary_key) > i:
+ mapper._set_attr_by_column(obj, col, primary_key[i])
i+=1
mapper._postfetch(connection, table, obj, c, c.last_inserted_params(), value_params)
# synchronize newly inserted ids from one table to the next
# TODO: this fires off more than needed, try to organize syncrules
# per table
- mappers = list(mapper.iterate_to_root())
- mappers.reverse()
- for m in mappers:
+ for m in util.reversed(list(mapper.iterate_to_root())):
if m._synchronizer is not None:
m._synchronizer.execute(obj, obj)
# testlib.pragma exempt:__hash__
inserted_objects.add((id(obj), obj, connection))
+
if not postupdate:
for id_, obj, connection in inserted_objects:
for mapper in object_mapper(obj).iterate_to_root():
for mapper in object_mapper(obj).iterate_to_root():
if 'after_update' in mapper.extension.methods:
mapper.extension.after_update(mapper, connection, obj)
-
+
def _postfetch(self, connection, table, obj, resultproxy, params, value_params):
"""After an ``INSERT`` or ``UPDATE``, assemble newly generated
values on an instance. For columns which are marked as being generated
postfetch_cols = resultproxy.postfetch_cols().union(util.Set(value_params.keys()))
deferred_props = []
- for c in table.c:
+ for c in self._cols_by_table[table]:
if c in postfetch_cols and (not c.key in params or c in value_params):
- prop = self._getpropbycolumn(c, raiseerror=False)
- if prop is None:
- continue
+ prop = self._columntoproperty[c]
deferred_props.append(prop.key)
continue
if c.primary_key or not c.key in params:
continue
- v = self.get_attr_by_column(obj, c, False)
- if v is NO_ATTRIBUTE:
- continue
- elif v != params[c.key]:
- self.set_attr_by_column(obj, c, params[c.key])
+ if self._get_attr_by_column(obj, c) != params[c.key]:
+ self._set_attr_by_column(obj, c, params[c.key])
if deferred_props:
expire_instance(obj, deferred_props)
table_to_mapper = {}
for mapper in self.base_mapper.polymorphic_iterator():
for t in mapper.tables:
- table_to_mapper.setdefault(t, mapper)
+ table_to_mapper[t] = mapper
for table in sqlutil.sort_tables(table_to_mapper.keys(), reverse=True):
delete = {}
for (obj, connection) in tups:
mapper = object_mapper(obj)
- if table not in mapper.tables or not mapper._has_pks(table):
+ if table not in mapper._pks_by_table:
continue
params = {}
continue
else:
delete.setdefault(connection, []).append(params)
- for col in mapper.pks_by_table[table]:
- params[col.key] = mapper.get_attr_by_column(obj, col)
+ for col in mapper._pks_by_table[table]:
+ params[col.key] = mapper._get_attr_by_column(obj, col)
if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col):
- params[mapper.version_id_col.key] = mapper.get_attr_by_column(obj, mapper.version_id_col)
+ params[mapper.version_id_col.key] = mapper._get_attr_by_column(obj, mapper.version_id_col)
# testlib.pragma exempt:__hash__
deleted_objects.add((id(obj), obj, connection))
for connection, del_objects in delete.iteritems():
mapper = table_to_mapper[table]
def comparator(a, b):
- for col in mapper.pks_by_table[table]:
+ for col in mapper._pks_by_table[table]:
x = cmp(a[col.key],b[col.key])
if x != 0:
return x
return 0
del_objects.sort(comparator)
clause = sql.and_()
- for col in mapper.pks_by_table[table]:
+ for col in mapper._pks_by_table[table]:
clause.clauses.append(col == sql.bindparam(col.key, type_=col.type, unique=True))
if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col):
clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type, unique=True))
if 'after_delete' in mapper.extension.methods:
mapper.extension.after_delete(mapper, connection, obj)
- def _has_pks(self, table):
- # TODO: determine this beforehand
- if self.pks_by_table.get(table, None):
- for k in self.pks_by_table[table]:
- if k not in self._columntoproperty:
- return False
- else:
- return True
- else:
- return False
-
def register_dependencies(self, uowcommit, *args, **kwargs):
"""Register ``DependencyProcessor`` instances with a
``unitofwork.UOWTransaction``.
prop.register_dependencies(uowcommit, *args, **kwargs)
def cascade_iterator(self, type, object, recursive=None, halt_on=None):
- """Iterate each element in an object graph, for all relations
- taht meet the given cascade rule.
+ """Iterate each element and its mapper in an object graph,
+ for all relations that meet the given cascade rule.
type
The name of the cascade rule (i.e. save-update, delete,
if recursive is None:
recursive=util.IdentitySet()
for prop in self.__props.values():
- for c in prop.cascade_iterator(type, object, recursive, halt_on=halt_on):
- yield c
-
- def cascade_callable(self, type, object, callable_, recursive=None, halt_on=None):
- """Execute a callable for each element in an object graph, for
- all relations that meet the given cascade rule.
-
- type
- The name of the cascade rule (i.e. save-update, delete, etc.)
-
- object
- The lead object instance. child items will be processed per
- the relations defined for this object's mapper.
-
- callable\_
- The callable function.
-
- recursive
- Used by the function for internal context during recursive
- calls, leave as None.
-
- """
-
- if recursive is None:
- recursive=util.IdentitySet()
- for prop in self.__props.values():
- prop.cascade_callable(type, object, callable_, recursive, halt_on=halt_on)
+ for (c, m) in prop.cascade_iterator(type, object, recursive, halt_on=halt_on):
+ yield (c, m)
def get_select_mapper(self):
"""Return the mapper used for issuing selects.
isnew = False
- if context.version_check and self.version_id_col is not None and self.get_attr_by_column(instance, self.version_id_col) != row[self.version_id_col]:
- raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self.get_attr_by_column(instance, self.version_id_col), row[self.version_id_col]))
+ if context.version_check and self.version_id_col is not None and self._get_attr_by_column(instance, self.version_id_col) != row[self.version_id_col]:
+ raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self._get_attr_by_column(instance, self.version_id_col), row[self.version_id_col]))
if context.populate_existing or self.always_refresh or instance._state.trigger is not None:
instance._state.trigger = None
params = {}
for c in param_names:
- params[c.name] = self.get_attr_by_column(instance, c)
+ params[c.name] = self._get_attr_by_column(instance, c)
row = selectcontext.session.connection(self).execute(statement, params).fetchone()
self.populate_instance(selectcontext, instance, row, isnew=False, instancekey=identitykey, ispostselect=True)
"""
from sqlalchemy import sql, schema, util, exceptions, logging
-from sqlalchemy.sql import util as sql_util, visitors
+from sqlalchemy.sql import util as sql_util, visitors, operators
from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency
from sqlalchemy.orm import session as sessionlib
from sqlalchemy.orm import util as mapperutil
-import operator
from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty
from sqlalchemy.exceptions import ArgumentError
+import warnings
__all__ = ['ColumnProperty', 'CompositeProperty', 'SynonymProperty', 'PropertyLoader', 'BackRef']
return strategies.DeferredColumnLoader(self)
else:
return strategies.ColumnLoader(self)
-
+
+ def do_init(self):
+ super(ColumnProperty, self).do_init()
+ if len(self.columns) > 1 and self.parent.primary_key.issuperset(self.columns):
+ warnings.warn(RuntimeWarning("On mapper %s, primary key column '%s' is being combined with distinct primary key column '%s' in attribute '%s'. Use explicit properties to give each column its own mapped attribute name." % (str(self.parent), str(self.columns[1]), str(self.columns[0]), self.key)))
+
def copy(self):
return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns)
col = self.prop.columns[0]
return op(col._bind_param(other), col)
-
ColumnProperty.logger = logging.class_logger(ColumnProperty)
-
class CompositeProperty(ColumnProperty):
"""subclasses ColumnProperty to provide composite type support."""
super(CompositeProperty, self).__init__(*columns, **kwargs)
self.composite_class = class_
self.comparator = kwargs.pop('comparator', CompositeProperty.Comparator)(self)
+
+ def do_init(self):
+ super(ColumnProperty, self).do_init()
+ # TODO: similar PK check as ColumnProperty does ?
def copy(self):
return CompositeProperty(deferred=self.deferred, group=self.group, composite_class=self.composite_class, *self.columns)
return ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]))
def compare(self, op, value, value_is_parent=False):
- if op == operator.eq:
+ if op == operators.eq:
if value is None:
return ~sql.exists([1], self.prop.mapper.mapped_table, self.prop.primaryjoin)
else:
if not isinstance(c, self.mapper.class_):
raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__)))
recursive.add(c)
- yield c
- for c2 in mapper.cascade_iterator(type, c, recursive):
- yield c2
-
- def cascade_callable(self, type, object, callable_, recursive, halt_on=None):
- if not type in self.cascade:
- return
-
- mapper = self.mapper.primary_mapper()
- passive = type != 'delete' or self.passive_deletes
- for c in attributes.get_as_list(object, self.key, passive=passive):
- if c is not None and c not in recursive and (halt_on is None or not halt_on(c)):
- if not isinstance(c, self.mapper.class_):
- raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__)))
- recursive.add(c)
- callable_(c, mapper.entity_name)
- mapper.cascade_callable(type, c, callable_, recursive)
+ yield (c, mapper)
+ for (c2, m) in mapper.cascade_iterator(type, c, recursive):
+ yield (c2, m)
def _get_target_class(self):
"""Return the target class of the relation, even if the
if self.foreign_keys:
self._opposite_side = util.Set()
def visit_binary(binary):
- if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+ if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
return
if binary.left in self.foreign_keys:
self._opposite_side.add(binary.right)
self.foreign_keys = util.Set()
self._opposite_side = util.Set()
def visit_binary(binary):
- if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+ if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
return
# this check is for when the user put the "view_only" flag on and has tables that have nothing
-# objectstore.py
+# session.py
# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
Note that this may not be per-flush if a longer running transaction is ongoing."""
- def before_flush(self, session, flush_context, objects):
+ def before_flush(self, session, flush_context, instances):
"""execute before flush process has started.
- 'objects' is an optional list of objects which were passed to the ``flush()``
+ 'instances' is an optional list of objects which were passed to the ``flush()``
method.
"""
entity_name = kwargs.pop('entity_name', None)
return self.query(class_, entity_name=entity_name).load(ident, **kwargs)
- def refresh(self, obj, attribute_names=None):
+ def refresh(self, instance, attribute_names=None):
"""Refresh the attributes on the given instance.
When called, a query will be issued
refreshed.
"""
- self._validate_persistent(obj)
+ self._validate_persistent(instance)
- if self.query(obj.__class__)._get(obj._instance_key, refresh_instance=obj, only_load_props=attribute_names) is None:
- raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(obj))
+ if self.query(instance.__class__)._get(instance._instance_key, refresh_instance=instance, only_load_props=attribute_names) is None:
+ raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance))
- def expire(self, obj, attribute_names=None):
+ def expire(self, instance, attribute_names=None):
"""Expire the attributes on the given instance.
The instance's attributes are instrumented such that
"""
if attribute_names:
- self._validate_persistent(obj)
- expire_instance(obj, attribute_names=attribute_names)
+ self._validate_persistent(instance)
+ expire_instance(instance, attribute_names=attribute_names)
else:
- for c in [obj] + list(_object_mapper(obj).cascade_iterator('refresh-expire', obj)):
- self._validate_persistent(obj)
+ # pre-fetch the full cascade since the expire is going to
+ # remove associations
+ cascaded = list(_cascade_iterator('refresh-expire', instance))
+ self._validate_persistent(instance)
+ expire_instance(instance, None)
+ for (c, m) in cascaded:
+ self._validate_persistent(c)
expire_instance(c, None)
def prune(self):
return self.uow.prune_identity_map()
- def expunge(self, object):
- """Remove the given `object` from this ``Session``.
+ def expunge(self, instance):
+ """Remove the given `instance` from this ``Session``.
- This will free all internal references to the object.
+ This will free all internal references to the instance.
Cascading will be applied according to the *expunge* cascade
rule.
"""
- self._validate_persistent(object)
- for c in [object] + list(_object_mapper(object).cascade_iterator('expunge', object)):
+ self._validate_persistent(instance)
+ for c, m in [(instance, None)] + list(_cascade_iterator('expunge', instance)):
if c in self:
self.uow._remove_deleted(c)
self._unattach(c)
- def save(self, object, entity_name=None):
+ def save(self, instance, entity_name=None):
"""Add a transient (unsaved) instance to this ``Session``.
This operation cascades the `save_or_update` method to
specific ``Mapper`` used to handle this instance.
"""
- self._save_impl(object, entity_name=entity_name)
- _object_mapper(object).cascade_callable('save-update', object,
- lambda c, e:self._save_or_update_impl(c, e),
- halt_on=lambda c:c in self)
+ self._save_impl(instance, entity_name=entity_name)
+ self._cascade_save_or_update(instance)
- def update(self, object, entity_name=None):
+ def update(self, instance, entity_name=None):
"""Bring the given detached (saved) instance into this
``Session``.
``cascade="save-update"``.
"""
- self._update_impl(object, entity_name=entity_name)
- _object_mapper(object).cascade_callable('save-update', object,
- lambda c, e:self._save_or_update_impl(c, e),
- halt_on=lambda c:c in self)
+ self._update_impl(instance, entity_name=entity_name)
+ self._cascade_save_or_update(instance)
- def save_or_update(self, object, entity_name=None):
- """Save or update the given object into this ``Session``.
+ def save_or_update(self, instance, entity_name=None):
+ """Save or update the given instance into this ``Session``.
The presence of an `_instance_key` attribute on the instance
determines whether to ``save()`` or ``update()`` the instance.
"""
- self._save_or_update_impl(object, entity_name=entity_name)
- _object_mapper(object).cascade_callable('save-update', object,
- lambda c, e:self._save_or_update_impl(c, e),
- halt_on=lambda c:c in self)
+ self._save_or_update_impl(instance, entity_name=entity_name)
+ self._cascade_save_or_update(instance)
+
+ def _cascade_save_or_update(self, instance):
+ for obj, mapper in _cascade_iterator('save-update', instance, halt_on=lambda c:c in self):
+ self._save_or_update_impl(obj, mapper.entity_name)
- def delete(self, object):
+ def delete(self, instance):
"""Mark the given instance as deleted.
The delete operation occurs upon ``flush()``.
"""
- self._delete_impl(object)
- for c in list(_object_mapper(object).cascade_iterator('delete', object)):
+ self._delete_impl(instance)
+ for c, m in _cascade_iterator('delete', instance):
self._delete_impl(c, ignore_transient=True)
- def merge(self, object, entity_name=None, dont_load=False, _recursive=None):
- """Copy the state of the given `object` onto the persistent
- object with the same identifier.
+ def merge(self, instance, entity_name=None, dont_load=False, _recursive=None):
+ """Copy the state of the given `instance` onto the persistent
+ instance with the same identifier.
If there is no persistent instance currently associated with
the session, it will be loaded. Return the persistent
if _recursive is None:
_recursive = {} #TODO: this should be an IdentityDict
if entity_name is not None:
- mapper = _class_mapper(object.__class__, entity_name=entity_name)
+ mapper = _class_mapper(instance.__class__, entity_name=entity_name)
else:
- mapper = _object_mapper(object)
- if object in _recursive:
- return _recursive[object]
+ mapper = _object_mapper(instance)
+ if instance in _recursive:
+ return _recursive[instance]
- key = getattr(object, '_instance_key', None)
+ key = getattr(instance, '_instance_key', None)
if key is None:
merged = attributes.new_instance(mapper.class_)
else:
if key in self.identity_map:
merged = self.identity_map[key]
elif dont_load:
- if object._state.modified:
+ if instance._state.modified:
raise exceptions.InvalidRequestError("merge() with dont_load=True option does not support objects marked as 'dirty'. flush() all changes on mapped instances before merging with dont_load=True.")
merged = attributes.new_instance(mapper.class_)
else:
merged = self.get(mapper.class_, key[1])
if merged is None:
- raise exceptions.AssertionError("Instance %s has an instance key but is not persisted" % mapperutil.instance_str(object))
- _recursive[object] = merged
+ raise exceptions.AssertionError("Instance %s has an instance key but is not persisted" % mapperutil.instance_str(instance))
+ _recursive[instance] = merged
for prop in mapper.iterate_properties:
- prop.merge(self, object, merged, dont_load, _recursive)
+ prop.merge(self, instance, merged, dont_load, _recursive)
if key is None:
self.save(merged, entity_name=mapper.entity_name)
elif dont_load:
return mapper.identity_key_from_instance(instance)
identity_key = classmethod(identity_key)
- def object_session(cls, obj):
+ def object_session(cls, instance):
- """return the ``Session`` to which the given object belongs."""
+ """return the ``Session`` to which the given instance belongs."""
- return object_session(obj)
+ return object_session(instance)
object_session = classmethod(object_session)
- def _save_impl(self, obj, **kwargs):
- if hasattr(obj, '_instance_key'):
- raise exceptions.InvalidRequestError("Instance '%s' is already persistent" % mapperutil.instance_str(obj))
+ def _save_impl(self, instance, **kwargs):
+ if hasattr(instance, '_instance_key'):
+ raise exceptions.InvalidRequestError("Instance '%s' is already persistent" % mapperutil.instance_str(instance))
else:
# TODO: consolidate the steps here
- attributes.manage(obj)
- obj._entity_name = kwargs.get('entity_name', None)
- self._attach(obj)
- self.uow.register_new(obj)
+ attributes.manage(instance)
+ instance._entity_name = kwargs.get('entity_name', None)
+ self._attach(instance)
+ self.uow.register_new(instance)
- def _update_impl(self, obj, **kwargs):
- if obj in self and obj not in self.deleted:
+ def _update_impl(self, instance, **kwargs):
+ if instance in self and instance not in self.deleted:
return
- if not hasattr(obj, '_instance_key'):
- raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(obj))
- elif self.identity_map.get(obj._instance_key, obj) is not obj:
+ if not hasattr(instance, '_instance_key'):
+ raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(instance))
+ elif self.identity_map.get(instance._instance_key, instance) is not instance:
- raise exceptions.InvalidRequestError("Could not update instance '%s', identity key %s; a different instance with the same identity key already exists in this session." % (mapperutil.instance_str(obj), obj._instance_key))
+ raise exceptions.InvalidRequestError("Could not update instance '%s', identity key %s; a different instance with the same identity key already exists in this session." % (mapperutil.instance_str(instance), instance._instance_key))
- self._attach(obj)
+ self._attach(instance)
- def _save_or_update_impl(self, object, entity_name=None):
- key = getattr(object, '_instance_key', None)
+ def _save_or_update_impl(self, instance, entity_name=None):
+ key = getattr(instance, '_instance_key', None)
if key is None:
- self._save_impl(object, entity_name=entity_name)
+ self._save_impl(instance, entity_name=entity_name)
else:
- self._update_impl(object, entity_name=entity_name)
+ self._update_impl(instance, entity_name=entity_name)
- def _delete_impl(self, obj, ignore_transient=False):
- if obj in self and obj in self.deleted:
+ def _delete_impl(self, instance, ignore_transient=False):
+ if instance in self and instance in self.deleted:
return
- if not hasattr(obj, '_instance_key'):
+ if not hasattr(instance, '_instance_key'):
if ignore_transient:
return
else:
- raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(obj))
- if self.identity_map.get(obj._instance_key, obj) is not obj:
- raise exceptions.InvalidRequestError("Instance '%s' is with key %s already persisted with a different identity" % (mapperutil.instance_str(obj), obj._instance_key))
- self._attach(obj)
- self.uow.register_deleted(obj)
-
- def _register_persistent(self, obj):
- obj._sa_session_id = self.hash_key
- self.identity_map[obj._instance_key] = obj
- obj._state.commit_all()
-
- def _attach(self, obj):
- old_id = getattr(obj, '_sa_session_id', None)
+ raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(instance))
+ if self.identity_map.get(instance._instance_key, instance) is not instance:
+ raise exceptions.InvalidRequestError("Instance '%s' is with key %s already persisted with a different identity" % (mapperutil.instance_str(instance), instance._instance_key))
+ self._attach(instance)
+ self.uow.register_deleted(instance)
+
+ def _register_persistent(self, instance):
+ instance._sa_session_id = self.hash_key
+ self.identity_map[instance._instance_key] = instance
+ instance._state.commit_all()
+
+ def _attach(self, instance):
+ old_id = getattr(instance, '_sa_session_id', None)
if old_id != self.hash_key:
- if old_id is not None and old_id in _sessions and obj in _sessions[old_id]:
+ if old_id is not None and old_id in _sessions and instance in _sessions[old_id]:
raise exceptions.InvalidRequestError("Object '%s' is already attached "
"to session '%s' (this is '%s')" %
- (mapperutil.instance_str(obj), old_id, id(self)))
+ (mapperutil.instance_str(instance), old_id, id(self)))
- key = getattr(obj, '_instance_key', None)
+ key = getattr(instance, '_instance_key', None)
if key is not None:
- self.identity_map[key] = obj
- obj._sa_session_id = self.hash_key
+ self.identity_map[key] = instance
+ instance._sa_session_id = self.hash_key
- def _unattach(self, obj):
- if obj._sa_session_id == self.hash_key:
- del obj._sa_session_id
+ def _unattach(self, instance):
+ if instance._sa_session_id == self.hash_key:
+ del instance._sa_session_id
- def _validate_persistent(self, obj):
- """Validate that the given object is persistent within this
+ def _validate_persistent(self, instance):
+ """Validate that the given instance is persistent within this
``Session``.
"""
- return obj in self
+ return instance in self
- def __contains__(self, obj):
- """return True if the given object is associated with this session.
+ def __contains__(self, instance):
+ """return True if the given instance is associated with this session.
The instance may be pending or persistent within the Session for a
result of True.
"""
- return obj in self.uow.new or (hasattr(obj, '_instance_key') and self.identity_map.get(obj._instance_key) is obj)
+ return instance in self.uow.new or (hasattr(instance, '_instance_key') and self.identity_map.get(instance._instance_key) is instance)
def __iter__(self):
- """return an iterator of all objects which are pending or persistent within this Session."""
+ """return an iterator of all instances which are pending or persistent within this Session."""
return iter(list(self.uow.new) + self.uow.identity_map.values())
- def is_modified(self, obj, include_collections=True, passive=False):
- """return True if the given object has modified attributes.
+ def is_modified(self, instance, include_collections=True, passive=False):
+ """return True if the given instance has modified attributes.
This method retrieves a history instance for each instrumented attribute
on the instance and performs a comparison of the current value to its
not be loaded in the course of performing this test.
"""
- for attr in attributes.managed_attributes(obj.__class__):
+ for attr in attributes.managed_attributes(instance.__class__):
if not include_collections and hasattr(attr.impl, 'get_collection'):
continue
- if attr.get_history(obj).is_modified():
+ if attr.get_history(instance).is_modified():
return True
return False
dirty = property(lambda s:s.uow.locate_dirty(),
- doc="""A ``Set`` of all objects marked as 'dirty' within this ``Session``.
+ doc="""A ``Set`` of all instances marked as 'dirty' within this ``Session``.
Note that the 'dirty' state here is 'optimistic'; most attribute-setting or collection
modification operations will mark an instance as 'dirty' and place it in this set,
""")
deleted = property(lambda s:s.uow.deleted,
- doc="A ``Set`` of all objects marked as 'deleted' within this ``Session``")
+ doc="A ``Set`` of all instances marked as 'deleted' within this ``Session``")
new = property(lambda s:s.uow.new,
- doc="A ``Set`` of all objects marked as 'new' within this ``Session``.")
+ doc="A ``Set`` of all instances marked as 'new' within this ``Session``.")
-def expire_instance(obj, attribute_names):
+def expire_instance(instance, attribute_names):
"""standalone expire instance function.
installs a callable with the given instance's _state
If the list is None or blank, the entire instance is expired.
"""
- if obj._state.trigger is None:
+ if instance._state.trigger is None:
def load_attributes(instance, attribute_names):
if object_session(instance).query(instance.__class__)._get(instance._instance_key, refresh_instance=instance, only_load_props=attribute_names) is None:
raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance))
- obj._state.trigger = load_attributes
+ instance._state.trigger = load_attributes
- obj._state.expire_attributes(attribute_names)
+ instance._state.expire_attributes(attribute_names)
register_attribute = unitofwork.register_attribute
-# this dictionary maps the hash key of a Session to the Session itself, and
-# acts as a Registry with which to locate Sessions. this is to enable
-# object instances to be associated with Sessions without having to attach the
-# actual Session object directly to the object instance.
_sessions = weakref.WeakValueDictionary()
-def object_session(obj):
- """Return the ``Session`` to which the given object is bound, or ``None`` if none."""
+def _cascade_iterator(cascade, instance, **kwargs):
+ mapper = _object_mapper(instance)
+ for (o, m) in mapper.cascade_iterator(cascade, instance, **kwargs):
+ yield o, m
+
+def object_session(instance):
+ """Return the ``Session`` to which the given instance is bound, or ``None`` if none."""
- hashkey = getattr(obj, '_sa_session_id', None)
+ hashkey = getattr(instance, '_sa_session_id', None)
if hashkey is not None:
sess = _sessions.get(hashkey)
- if obj in sess:
+ if instance in sess:
return sess
return None
def create_statement(instance):
params = {}
for c in param_names:
- params[c.name] = mapper.get_attr_by_column(instance, c)
+ params[c.name] = mapper._get_attr_by_column(instance, c)
return (statement, params)
def new_execute(instance, row, isnew, **flags):
def visit_bindparam(s, bindparam):
mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent
if bindparam.key in bind_to_col:
- bindparam.value = mapper.get_attr_by_column(instance, bind_to_col[bindparam.key])
+ bindparam.value = mapper._get_attr_by_column(instance, bind_to_col[bindparam.key])
return Visitor().traverse(criterion, clone=True)
def setup_loader(self, instance, options=None, path=None):
if self.use_get:
params = {}
for col, bind in self.lazybinds.iteritems():
- params[bind.key] = self.parent.get_attr_by_column(instance, col)
+ params[bind.key] = self.parent._get_attr_by_column(instance, col)
ident = []
nonnulls = False
for primary_key in self.select_mapper.primary_key:
#print "SyncRule", source_mapper, source_column, dest_column, dest_mapper
def dest_primary_key(self):
+ # late-evaluating boolean since some syncs are created
+ # before the mapper has assembled pks
try:
return self._dest_primary_key
except AttributeError:
- self._dest_primary_key = self.dest_mapper is not None and self.dest_column in self.dest_mapper.pks_by_table[self.dest_column.table] and not self.dest_mapper.allow_null_pks
+ self._dest_primary_key = self.dest_mapper is not None and self.dest_column in self.dest_mapper._pks_by_table[self.dest_column.table] and not self.dest_mapper.allow_null_pks
return self._dest_primary_key
def execute(self, source, dest, obj, child, clearkeys):
value = None
clearkeys = True
else:
- value = self.source_mapper.get_attr_by_column(source, self.source_column)
+ value = self.source_mapper._get_attr_by_column(source, self.source_column)
if isinstance(dest, dict):
dest[self.dest_column.key] = value
else:
if logging.is_debug_enabled(self.logger):
self.logger.debug("execute() instances: %s(%s)->%s(%s) ('%s')" % (mapperutil.instance_str(source), str(self.source_column), mapperutil.instance_str(dest), str(self.dest_column), value))
- self.dest_mapper.set_attr_by_column(dest, self.dest_column, value)
+ self.dest_mapper._set_attr_by_column(dest, self.dest_column, value)
SyncRule.logger = logging.class_logger(SyncRule)
vis.traverse(table)
sequence = topological.QueueDependencySorter( tuples, tables).sort(create_tree=False)
if reverse:
- sequence.reverse()
- return sequence
+ return util.reversed(sequence)
+ else:
+ return sequence
def find_tables(clause, check_columns=False, include_aliases=False):
tables = []
self._do_test(True)
assert False
except RuntimeWarning, e:
- assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'. Use explicit properties to give each column its own mapped attribute name."
+ assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'. Use explicit properties to give each column its own mapped attribute name.", str(e)
def test_explicit_pk(self):
person_mapper = mapper(Person, person_table)
compile_mappers()
assert False
except exceptions.ArgumentError, e:
- assert str(e) == "Error creating backref 'transitions' on relation 'Transition.places (Place)': property of that name exists on mapper 'Mapper|Place|place'"
+ assert str(e) in [
+ "Error creating backref 'transitions' on relation 'Transition.places (Place)': property of that name exists on mapper 'Mapper|Place|place'",
+ "Error creating backref 'places' on relation 'Place.transitions (Transition)': property of that name exists on mapper 'Mapper|Transition|transition'"
+ ]
+
def testcircular(self):
"""tests a many-to-many relationship from a table to itself."""
class A(object):pass
m = mapper(A, account_ids_table.join(account_stuff_table))
m.compile()
- assert m._has_pks(account_ids_table)
- assert not m._has_pks(account_stuff_table)
+ assert account_ids_table in m._pks_by_table
+ assert account_stuff_table not in m._pks_by_table
metadata.create_all(testbase.db)
try:
sess = create_session(bind=testbase.db)