From c6d01a56e168f1c70461c2684c70a2c5967c4814 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 1 Dec 2007 23:00:05 +0000 Subject: [PATCH] - several ORM attributes have been removed or made private: mapper.get_attr_by_column(), mapper.set_attr_by_column(), mapper.pks_by_table, mapper.cascade_callable(), MapperProperty.cascade_callable(), mapper.canload() - refinements to mapper PK/table column organization, session cascading, some naming convention work --- CHANGES | 5 + lib/sqlalchemy/orm/collections.py | 4 +- lib/sqlalchemy/orm/dependency.py | 10 +- lib/sqlalchemy/orm/interfaces.py | 15 +- lib/sqlalchemy/orm/mapper.py | 225 ++++++++++------------------ lib/sqlalchemy/orm/properties.py | 43 +++--- lib/sqlalchemy/orm/session.py | 238 +++++++++++++++--------------- lib/sqlalchemy/orm/strategies.py | 6 +- lib/sqlalchemy/orm/sync.py | 8 +- lib/sqlalchemy/sql/util.py | 5 +- test/orm/inheritance/basic.py | 2 +- test/orm/manytomany.py | 6 +- test/orm/mapper.py | 4 +- 13 files changed, 256 insertions(+), 315 deletions(-) diff --git a/CHANGES b/CHANGES index 1dc797d77e..6e275254f3 100644 --- a/CHANGES +++ b/CHANGES @@ -34,6 +34,11 @@ CHANGES relationship where it takes effect for all inheriting mappers. [ticket:883] + - several ORM attributes have been removed or made private: + mapper.get_attr_by_column(), mapper.set_attr_by_column(), + mapper.pks_by_table, mapper.cascade_callable(), + MapperProperty.cascade_callable(), mapper.canload() + - fixed endless loop issue when using lazy="dynamic" on both sides of a bi-directional relationship [ticket:872] diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 942b880c94..c2cd4cf09d 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -122,7 +122,7 @@ def column_mapped_collection(mapping_spec): if isinstance(mapping_spec, schema.Column): def keyfunc(value): m = object_mapper(value) - return m.get_attr_by_column(value, mapping_spec) + return m._get_attr_by_column(value, mapping_spec) else: cols = [] for c in mapping_spec: @@ -133,7 +133,7 @@ def column_mapped_collection(mapping_spec): mapping_spec = tuple(cols) def keyfunc(value): m = object_mapper(value) - return tuple([m.get_attr_by_column(value, c) for c in mapping_spec]) + return tuple([m._get_attr_by_column(value, c) for c in mapping_spec]) return lambda: MappedCollection(keyfunc) def attribute_mapped_collection(attr_name): diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index 9688999169..9220c5743b 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -111,7 +111,7 @@ class DependencyProcessor(object): def _verify_canload(self, child): if not self.enable_typechecks: return - if child is not None and not self.mapper.canload(child): + if child is not None and not self.mapper._canload(child): raise exceptions.FlushError("Attempting to flush an item of type %s on collection '%s', which is handled by mapper '%s' and does not load items of that type. Did you mean to use a polymorphic mapper for this relationship ? Set 'enable_typechecks=False' on the relation() to disable this exception. Mismatched typeloading may cause bi-directional relationships (backrefs) to not function properly." % (child.__class__, self.prop, self.mapper)) def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit): @@ -240,7 +240,7 @@ class OneToManyDP(DependencyProcessor): uowcommit.register_object(child, isdelete=False) elif self.hasparent(child) is False: uowcommit.register_object(child, isdelete=True) - for c in self.mapper.cascade_iterator('delete', child): + for c, m in self.mapper.cascade_iterator('delete', child): uowcommit.register_object(c, isdelete=True) def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit): @@ -294,7 +294,7 @@ class ManyToOneDP(DependencyProcessor): for child in childlist.deleted_items() + childlist.unchanged_items(): if child is not None and self.hasparent(child) is False: uowcommit.register_object(child, isdelete=True) - for c in self.mapper.cascade_iterator('delete', child): + for c, m in self.mapper.cascade_iterator('delete', child): uowcommit.register_object(c, isdelete=True) else: for obj in deplist: @@ -305,7 +305,7 @@ class ManyToOneDP(DependencyProcessor): for child in childlist.deleted_items(): if self.hasparent(child) is False: uowcommit.register_object(child, isdelete=True) - for c in self.mapper.cascade_iterator('delete', child): + for c, m in self.mapper.cascade_iterator('delete', child): uowcommit.register_object(c, isdelete=True) def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit): @@ -391,7 +391,7 @@ class ManyToManyDP(DependencyProcessor): for child in childlist.deleted_items(): if self.cascade.delete_orphan and self.hasparent(child) is False: uowcommit.register_object(child, isdelete=True) - for c in self.mapper.cascade_iterator('delete', child): + for c, m in self.mapper.cascade_iterator('delete', child): uowcommit.register_object(c, isdelete=True) def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit): diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index aa0b2dcc24..815ea8cebb 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -319,18 +319,15 @@ class MapperProperty(object): """ raise NotImplementedError() - + def cascade_iterator(self, type, object, recursive=None, halt_on=None): - """return an iterator of objects which are child objects of the given object, - as attached to the attribute corresponding to this MapperProperty.""" + """iterate through instances related to the given instance along + a particular 'cascade' path, starting with this MapperProperty. - return [] - - def cascade_callable(self, type, object, callable_, recursive=None, halt_on=None): - """run the given callable across all objects which are child objects of - the given object, as attached to the attribute corresponding to this MapperProperty.""" + see PropertyLoader for the related instance implementation. + """ - return [] + return iter([]) def get_criterion(self, query, key, value): """Return a ``WHERE`` clause suitable for this diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 5673be44c1..67087c5708 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -4,15 +4,15 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import weakref, warnings, operator +import weakref, warnings +from itertools import chain from sqlalchemy import sql, util, exceptions, logging -from sqlalchemy.sql import expression, visitors +from sqlalchemy.sql import expression, visitors, operators from sqlalchemy.sql import util as sqlutil from sqlalchemy.orm import util as mapperutil from sqlalchemy.orm.util import ExtensionCarrier, create_row_adapter from sqlalchemy.orm import sync, attributes from sqlalchemy.orm.interfaces import MapperProperty, EXT_CONTINUE, PropComparator -deferred_load = None __all__ = ['Mapper', 'class_mapper', 'object_mapper', 'mapper_registry'] @@ -22,7 +22,7 @@ mapper_registry = weakref.WeakKeyDictionary() # a list of MapperExtensions that will be installed in all mappers by default global_extensions = [] -# a constant returned by get_attr_by_column to indicate +# a constant returned by _get_attr_by_column to indicate # this mapper is not handling an attribute for a particular # column NO_ATTRIBUTE = object() @@ -152,6 +152,7 @@ class Mapper(object): self._compile_inheritance() self._compile_tables() self._compile_properties() + self._compile_pks() self._compile_selectable() self.__log("constructed") @@ -376,12 +377,6 @@ class Mapper(object): self.polymorphic_map[key] = class_or_mapper def _compile_tables(self): - """After the inheritance relationships have been reconciled, - set up some more table-based instance variables and determine - the *primary key* columns for all tables represented by this - ``Mapper``. - """ - # summary of the various Selectable units: # mapped_table - the Selectable that represents a join of the underlying Tables to be saved (or just the Table) # local_table - the Selectable that was passed to this Mapper's constructor, if any @@ -401,27 +396,25 @@ class Mapper(object): if not self.tables: raise exceptions.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table)) - # TODO: move the "figure pks" step down into compile_properties; after - # all columns have been mapped, assemble PK columns and their - # proxied parents into the pks_by_table collection, then get rid - # of the _has_pks method - - # determine primary key columns - self.pks_by_table = {} + def _compile_pks(self): - # go through all of our represented tables - # and assemble primary key columns - for t in self.tables + [self.mapped_table]: + self._pks_by_table = {} + self._cols_by_table = {} + + all_cols = util.Set(chain(*[c2 for c2 in [col.proxy_set for col in [c for c in self._columntoproperty]]])) + pk_cols = util.Set([c for c in all_cols if c.primary_key]) + + for t in util.Set(self.tables + [self.mapped_table]): self._all_tables.add(t) - if t not in self.pks_by_table: - self.pks_by_table[t] = util.OrderedSet() - self.pks_by_table[t].update(t.primary_key) - - if self.primary_key_argument is not None: + if t.primary_key and pk_cols.issuperset(t.primary_key): + self._pks_by_table[t] = util.Set(t.primary_key).intersection(pk_cols) + self._cols_by_table[t] = util.Set(t.c).intersection(all_cols) + + if self.primary_key_argument: for k in self.primary_key_argument: - self.pks_by_table.setdefault(k.table, util.OrderedSet()).add(k) + self._pks_by_table.setdefault(k.table, util.Set()).add(k) - if len(self.pks_by_table[self.mapped_table]) == 0: + if len(self._pks_by_table[self.mapped_table]) == 0: raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name)) if self.inherits is not None and not self.concrete and not self.primary_key_argument: @@ -437,7 +430,7 @@ class Mapper(object): primary_key = expression.ColumnSet() - for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]): + for col in (self.primary_key_argument or self._pks_by_table[self.mapped_table]): c = self.mapped_table.corresponding_column(col, raiseerr=False) if c is None: for cc in self._equivalent_columns[col]: @@ -474,7 +467,10 @@ class Mapper(object): self.primary_key = primary_key self.__log("Identified primary key columns: " + str(primary_key)) - + + # create a "get clause" based on the primary key. this is used + # by query.get() and many-to-one lazyloads to load this item + # by primary key. _get_clause = sql.and_() _get_params = {} for primary_key in self.primary_key: @@ -510,7 +506,7 @@ class Mapper(object): result = {} def visit_binary(binary): - if binary.operator == operator.eq: + if binary.operator == operators.eq: if binary.left in result: result[binary.left].add(binary.right) else: @@ -533,13 +529,17 @@ class Mapper(object): return recursive.add(col) for fk in col.foreign_keys: - result.setdefault(fk.column, util.Set()).add(equiv) + if fk.column not in result: + result[fk.column] = util.Set() + result[fk.column].add(equiv) equivs(fk.column, recursive, col) - for column in (self.primary_key_argument or self.pks_by_table[self.mapped_table]): + for column in (self.primary_key_argument or self._pks_by_table[self.mapped_table]): for col in column.proxy_set: if not col.foreign_keys: - result.setdefault(col, util.Set()).add(col) + if col not in result: + result[col] = util.Set() + result[col].add(col) else: equivs(col, util.Set(), col) @@ -571,11 +571,6 @@ class Mapper(object): return getattr(getattr(cls, clskey), key) def _compile_properties(self): - """Inspect the properties dictionary sent to the Mapper's - constructor as well as the mapped_table, and create - ``MapperProperty`` objects corresponding to each mapped column - and relation. - """ # object attribute names mapped to MapperProperty objects self.__props = util.OrderedDict() @@ -637,9 +632,6 @@ class Mapper(object): # TODO: the "property already exists" case is still not well defined here. # assuming single-column, etc. - if column in self.primary_key and prop.columns[-1] in self.primary_key: - warnings.warn(RuntimeWarning("On mapper %s, primary key column '%s' is being combined with distinct primary key column '%s' in attribute '%s'. Use explicit properties to give each column its own mapped attribute name." % (str(self), str(column), str(prop.columns[-1]), key))) - if prop.parent is not self: # existing ColumnProperty from an inheriting mapper. # make a copy and append our column to it @@ -896,38 +888,27 @@ class Mapper(object): instance. """ - return [self.get_attr_by_column(instance, column) for column in self.primary_key] + return [self._get_attr_by_column(instance, column) for column in self.primary_key] - def canload(self, instance): + def _canload(self, instance): """return true if this mapper is capable of loading the given instance""" if self.polymorphic_on is not None: return isinstance(instance, self.class_) else: return instance.__class__ is self.class_ - def _getpropbycolumn(self, column, raiseerror=True): + def _get_attr_by_column(self, obj, column): + """Return an instance attribute using a Column as the key.""" try: - return self._columntoproperty[column] + return self._columntoproperty[column].getattr(obj, column) except KeyError: - try: - prop = self.__props[column.key] - if not raiseerror: - return None + prop = self.__props.get(column.key, None) + if prop: raise exceptions.InvalidRequestError("Column '%s.%s' is not available, due to conflicting property '%s':%s" % (column.table.name, column.name, column.key, repr(prop))) - except KeyError: - if not raiseerror: - return None + else: raise exceptions.InvalidRequestError("No column %s.%s is configured on mapper %s..." % (column.table.name, column.name, str(self))) - - def get_attr_by_column(self, obj, column, raiseerror=True): - """Return an instance attribute using a Column as the key.""" - - prop = self._getpropbycolumn(column, raiseerror) - if prop is None: - return NO_ATTRIBUTE - return prop.getattr(obj, column) - - def set_attr_by_column(self, obj, column, value): + + def _set_attr_by_column(self, obj, column, value): """Set the value of an instance attribute using a Column as the key.""" self._columntoproperty[column].setattr(obj, value, column) @@ -996,18 +977,18 @@ class Mapper(object): table_to_mapper = {} for mapper in self.base_mapper.polymorphic_iterator(): for t in mapper.tables: - table_to_mapper.setdefault(t, mapper) + table_to_mapper[t] = mapper - for table in sqlutil.sort_tables(table_to_mapper.keys(), reverse=False): + for table in sqlutil.sort_tables(table_to_mapper.keys()): # two lists to store parameters for each table/object pair located insert = [] update = [] for obj, connection in tups: mapper = object_mapper(obj) - if table not in mapper.tables or not mapper._has_pks(table): + if table not in mapper._pks_by_table: continue - pks = mapper.pks_by_table[table] + pks = mapper._pks_by_table[table] instance_key = mapper.identity_key_from_instance(obj) if self.__should_log_debug: @@ -1019,11 +1000,11 @@ class Mapper(object): hasdata = False if isinsert: - for col in table.columns: + for col in mapper._cols_by_table[table]: if col is mapper.version_id_col: params[col.key] = 1 elif col in pks: - value = mapper.get_attr_by_column(obj, col) + value = mapper._get_attr_by_column(obj, col) if value is not None: params[col.key] = value elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col): @@ -1033,9 +1014,7 @@ class Mapper(object): if col.default is None or value is not None: params[col.key] = value else: - value = mapper.get_attr_by_column(obj, col, False) - if value is NO_ATTRIBUTE: - continue + value = mapper._get_attr_by_column(obj, col) if col.default is None or value is not None: if isinstance(value, sql.ClauseElement): value_params[col] = value @@ -1043,24 +1022,22 @@ class Mapper(object): params[col.key] = value insert.append((obj, params, mapper, connection, value_params)) else: - for col in table.columns: + for col in mapper._cols_by_table[table]: if col is mapper.version_id_col: - params[col._label] = mapper.get_attr_by_column(obj, col) + params[col._label] = mapper._get_attr_by_column(obj, col) params[col.key] = params[col._label] + 1 for prop in mapper._columntoproperty.values(): history = attributes.get_history(obj, prop.key, passive=True) if history and history.added_items(): hasdata = True elif col in pks: - params[col._label] = mapper.get_attr_by_column(obj, col) + params[col._label] = mapper._get_attr_by_column(obj, col) elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col): pass else: if post_update_cols is not None and col not in post_update_cols: continue - prop = mapper._getpropbycolumn(col, False) - if prop is None: - continue + prop = mapper._columntoproperty[col] history = attributes.get_history(obj, prop.key, passive=True) if history: a = history.added_items() @@ -1076,14 +1053,14 @@ class Mapper(object): if update: mapper = table_to_mapper[table] clause = sql.and_() - for col in mapper.pks_by_table[table]: + for col in mapper._pks_by_table[table]: clause.clauses.append(col == sql.bindparam(col._label, type_=col.type, unique=True)) if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col): clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type_=col.type, unique=True)) statement = table.update(clause) rows = 0 supports_sane_rowcount = True - pks = mapper.pks_by_table[table] + pks = mapper._pks_by_table[table] def comparator(a, b): for col in pks: x = cmp(a[1][col._label],b[1][col._label]) @@ -1115,23 +1092,22 @@ class Mapper(object): if primary_key is not None: i = 0 - for col in mapper.pks_by_table[table]: - if mapper.get_attr_by_column(obj, col) is None and len(primary_key) > i: - mapper.set_attr_by_column(obj, col, primary_key[i]) + for col in mapper._pks_by_table[table]: + if mapper._get_attr_by_column(obj, col) is None and len(primary_key) > i: + mapper._set_attr_by_column(obj, col, primary_key[i]) i+=1 mapper._postfetch(connection, table, obj, c, c.last_inserted_params(), value_params) # synchronize newly inserted ids from one table to the next # TODO: this fires off more than needed, try to organize syncrules # per table - mappers = list(mapper.iterate_to_root()) - mappers.reverse() - for m in mappers: + for m in util.reversed(list(mapper.iterate_to_root())): if m._synchronizer is not None: m._synchronizer.execute(obj, obj) # testlib.pragma exempt:__hash__ inserted_objects.add((id(obj), obj, connection)) + if not postupdate: for id_, obj, connection in inserted_objects: for mapper in object_mapper(obj).iterate_to_root(): @@ -1141,7 +1117,7 @@ class Mapper(object): for mapper in object_mapper(obj).iterate_to_root(): if 'after_update' in mapper.extension.methods: mapper.extension.after_update(mapper, connection, obj) - + def _postfetch(self, connection, table, obj, resultproxy, params, value_params): """After an ``INSERT`` or ``UPDATE``, assemble newly generated values on an instance. For columns which are marked as being generated @@ -1152,20 +1128,15 @@ class Mapper(object): postfetch_cols = resultproxy.postfetch_cols().union(util.Set(value_params.keys())) deferred_props = [] - for c in table.c: + for c in self._cols_by_table[table]: if c in postfetch_cols and (not c.key in params or c in value_params): - prop = self._getpropbycolumn(c, raiseerror=False) - if prop is None: - continue + prop = self._columntoproperty[c] deferred_props.append(prop.key) continue if c.primary_key or not c.key in params: continue - v = self.get_attr_by_column(obj, c, False) - if v is NO_ATTRIBUTE: - continue - elif v != params[c.key]: - self.set_attr_by_column(obj, c, params[c.key]) + if self._get_attr_by_column(obj, c) != params[c.key]: + self._set_attr_by_column(obj, c, params[c.key]) if deferred_props: expire_instance(obj, deferred_props) @@ -1196,13 +1167,13 @@ class Mapper(object): table_to_mapper = {} for mapper in self.base_mapper.polymorphic_iterator(): for t in mapper.tables: - table_to_mapper.setdefault(t, mapper) + table_to_mapper[t] = mapper for table in sqlutil.sort_tables(table_to_mapper.keys(), reverse=True): delete = {} for (obj, connection) in tups: mapper = object_mapper(obj) - if table not in mapper.tables or not mapper._has_pks(table): + if table not in mapper._pks_by_table: continue params = {} @@ -1210,23 +1181,23 @@ class Mapper(object): continue else: delete.setdefault(connection, []).append(params) - for col in mapper.pks_by_table[table]: - params[col.key] = mapper.get_attr_by_column(obj, col) + for col in mapper._pks_by_table[table]: + params[col.key] = mapper._get_attr_by_column(obj, col) if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col): - params[mapper.version_id_col.key] = mapper.get_attr_by_column(obj, mapper.version_id_col) + params[mapper.version_id_col.key] = mapper._get_attr_by_column(obj, mapper.version_id_col) # testlib.pragma exempt:__hash__ deleted_objects.add((id(obj), obj, connection)) for connection, del_objects in delete.iteritems(): mapper = table_to_mapper[table] def comparator(a, b): - for col in mapper.pks_by_table[table]: + for col in mapper._pks_by_table[table]: x = cmp(a[col.key],b[col.key]) if x != 0: return x return 0 del_objects.sort(comparator) clause = sql.and_() - for col in mapper.pks_by_table[table]: + for col in mapper._pks_by_table[table]: clause.clauses.append(col == sql.bindparam(col.key, type_=col.type, unique=True)) if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col): clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type, unique=True)) @@ -1240,17 +1211,6 @@ class Mapper(object): if 'after_delete' in mapper.extension.methods: mapper.extension.after_delete(mapper, connection, obj) - def _has_pks(self, table): - # TODO: determine this beforehand - if self.pks_by_table.get(table, None): - for k in self.pks_by_table[table]: - if k not in self._columntoproperty: - return False - else: - return True - else: - return False - def register_dependencies(self, uowcommit, *args, **kwargs): """Register ``DependencyProcessor`` instances with a ``unitofwork.UOWTransaction``. @@ -1263,8 +1223,8 @@ class Mapper(object): prop.register_dependencies(uowcommit, *args, **kwargs) def cascade_iterator(self, type, object, recursive=None, halt_on=None): - """Iterate each element in an object graph, for all relations - taht meet the given cascade rule. + """Iterate each element and its mapper in an object graph, + for all relations that meet the given cascade rule. type The name of the cascade rule (i.e. save-update, delete, @@ -1282,33 +1242,8 @@ class Mapper(object): if recursive is None: recursive=util.IdentitySet() for prop in self.__props.values(): - for c in prop.cascade_iterator(type, object, recursive, halt_on=halt_on): - yield c - - def cascade_callable(self, type, object, callable_, recursive=None, halt_on=None): - """Execute a callable for each element in an object graph, for - all relations that meet the given cascade rule. - - type - The name of the cascade rule (i.e. save-update, delete, etc.) - - object - The lead object instance. child items will be processed per - the relations defined for this object's mapper. - - callable\_ - The callable function. - - recursive - Used by the function for internal context during recursive - calls, leave as None. - - """ - - if recursive is None: - recursive=util.IdentitySet() - for prop in self.__props.values(): - prop.cascade_callable(type, object, callable_, recursive, halt_on=halt_on) + for (c, m) in prop.cascade_iterator(type, object, recursive, halt_on=halt_on): + yield (c, m) def get_select_mapper(self): """Return the mapper used for issuing selects. @@ -1365,8 +1300,8 @@ class Mapper(object): isnew = False - if context.version_check and self.version_id_col is not None and self.get_attr_by_column(instance, self.version_id_col) != row[self.version_id_col]: - raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self.get_attr_by_column(instance, self.version_id_col), row[self.version_id_col])) + if context.version_check and self.version_id_col is not None and self._get_attr_by_column(instance, self.version_id_col) != row[self.version_id_col]: + raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self._get_attr_by_column(instance, self.version_id_col), row[self.version_id_col])) if context.populate_existing or self.always_refresh or instance._state.trigger is not None: instance._state.trigger = None @@ -1541,7 +1476,7 @@ class Mapper(object): params = {} for c in param_names: - params[c.name] = self.get_attr_by_column(instance, c) + params[c.name] = self._get_attr_by_column(instance, c) row = selectcontext.session.connection(self).execute(statement, params).fetchone() self.populate_instance(selectcontext, instance, row, isnew=False, instancekey=identitykey, ispostselect=True) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 2c50ec92fc..806c91a2b7 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -12,13 +12,13 @@ to handle flush-time dependency sorting and processing. """ from sqlalchemy import sql, schema, util, exceptions, logging -from sqlalchemy.sql import util as sql_util, visitors +from sqlalchemy.sql import util as sql_util, visitors, operators from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency from sqlalchemy.orm import session as sessionlib from sqlalchemy.orm import util as mapperutil -import operator from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty from sqlalchemy.exceptions import ArgumentError +import warnings __all__ = ['ColumnProperty', 'CompositeProperty', 'SynonymProperty', 'PropertyLoader', 'BackRef'] @@ -48,7 +48,12 @@ class ColumnProperty(StrategizedProperty): return strategies.DeferredColumnLoader(self) else: return strategies.ColumnLoader(self) - + + def do_init(self): + super(ColumnProperty, self).do_init() + if len(self.columns) > 1 and self.parent.primary_key.issuperset(self.columns): + warnings.warn(RuntimeWarning("On mapper %s, primary key column '%s' is being combined with distinct primary key column '%s' in attribute '%s'. Use explicit properties to give each column its own mapped attribute name." % (str(self.parent), str(self.columns[1]), str(self.columns[0]), self.key))) + def copy(self): return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns) @@ -75,10 +80,8 @@ class ColumnProperty(StrategizedProperty): col = self.prop.columns[0] return op(col._bind_param(other), col) - ColumnProperty.logger = logging.class_logger(ColumnProperty) - class CompositeProperty(ColumnProperty): """subclasses ColumnProperty to provide composite type support.""" @@ -86,6 +89,10 @@ class CompositeProperty(ColumnProperty): super(CompositeProperty, self).__init__(*columns, **kwargs) self.composite_class = class_ self.comparator = kwargs.pop('comparator', CompositeProperty.Comparator)(self) + + def do_init(self): + super(ColumnProperty, self).do_init() + # TODO: similar PK check as ColumnProperty does ? def copy(self): return CompositeProperty(deferred=self.deferred, group=self.group, composite_class=self.composite_class, *self.columns) @@ -283,7 +290,7 @@ class PropertyLoader(StrategizedProperty): return ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])) def compare(self, op, value, value_is_parent=False): - if op == operator.eq: + if op == operators.eq: if value is None: return ~sql.exists([1], self.prop.mapper.mapped_table, self.prop.primaryjoin) else: @@ -347,23 +354,9 @@ class PropertyLoader(StrategizedProperty): if not isinstance(c, self.mapper.class_): raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__))) recursive.add(c) - yield c - for c2 in mapper.cascade_iterator(type, c, recursive): - yield c2 - - def cascade_callable(self, type, object, callable_, recursive, halt_on=None): - if not type in self.cascade: - return - - mapper = self.mapper.primary_mapper() - passive = type != 'delete' or self.passive_deletes - for c in attributes.get_as_list(object, self.key, passive=passive): - if c is not None and c not in recursive and (halt_on is None or not halt_on(c)): - if not isinstance(c, self.mapper.class_): - raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__))) - recursive.add(c) - callable_(c, mapper.entity_name) - mapper.cascade_callable(type, c, callable_, recursive) + yield (c, mapper) + for (c2, m) in mapper.cascade_iterator(type, c, recursive): + yield (c2, m) def _get_target_class(self): """Return the target class of the relation, even if the @@ -464,7 +457,7 @@ class PropertyLoader(StrategizedProperty): if self.foreign_keys: self._opposite_side = util.Set() def visit_binary(binary): - if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): + if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): return if binary.left in self.foreign_keys: self._opposite_side.add(binary.right) @@ -477,7 +470,7 @@ class PropertyLoader(StrategizedProperty): self.foreign_keys = util.Set() self._opposite_side = util.Set() def visit_binary(binary): - if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): + if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): return # this check is for when the user put the "view_only" flag on and has tables that have nothing diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 7097273f59..28ef39aba4 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1,4 +1,4 @@ -# objectstore.py +# session.py # Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under @@ -103,10 +103,10 @@ class SessionExtension(object): Note that this may not be per-flush if a longer running transaction is ongoing.""" - def before_flush(self, session, flush_context, objects): + def before_flush(self, session, flush_context, instances): """execute before flush process has started. - 'objects' is an optional list of objects which were passed to the ``flush()`` + 'instances' is an optional list of objects which were passed to the ``flush()`` method. """ @@ -719,7 +719,7 @@ class Session(object): entity_name = kwargs.pop('entity_name', None) return self.query(class_, entity_name=entity_name).load(ident, **kwargs) - def refresh(self, obj, attribute_names=None): + def refresh(self, instance, attribute_names=None): """Refresh the attributes on the given instance. When called, a query will be issued @@ -738,12 +738,12 @@ class Session(object): refreshed. """ - self._validate_persistent(obj) + self._validate_persistent(instance) - if self.query(obj.__class__)._get(obj._instance_key, refresh_instance=obj, only_load_props=attribute_names) is None: - raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(obj)) + if self.query(instance.__class__)._get(instance._instance_key, refresh_instance=instance, only_load_props=attribute_names) is None: + raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance)) - def expire(self, obj, attribute_names=None): + def expire(self, instance, attribute_names=None): """Expire the attributes on the given instance. The instance's attributes are instrumented such that @@ -764,11 +764,16 @@ class Session(object): """ if attribute_names: - self._validate_persistent(obj) - expire_instance(obj, attribute_names=attribute_names) + self._validate_persistent(instance) + expire_instance(instance, attribute_names=attribute_names) else: - for c in [obj] + list(_object_mapper(obj).cascade_iterator('refresh-expire', obj)): - self._validate_persistent(obj) + # pre-fetch the full cascade since the expire is going to + # remove associations + cascaded = list(_cascade_iterator('refresh-expire', instance)) + self._validate_persistent(instance) + expire_instance(instance, None) + for (c, m) in cascaded: + self._validate_persistent(c) expire_instance(c, None) def prune(self): @@ -784,20 +789,20 @@ class Session(object): return self.uow.prune_identity_map() - def expunge(self, object): - """Remove the given `object` from this ``Session``. + def expunge(self, instance): + """Remove the given `instance` from this ``Session``. - This will free all internal references to the object. + This will free all internal references to the instance. Cascading will be applied according to the *expunge* cascade rule. """ - self._validate_persistent(object) - for c in [object] + list(_object_mapper(object).cascade_iterator('expunge', object)): + self._validate_persistent(instance) + for c, m in [(instance, None)] + list(_cascade_iterator('expunge', instance)): if c in self: self.uow._remove_deleted(c) self._unattach(c) - def save(self, object, entity_name=None): + def save(self, instance, entity_name=None): """Add a transient (unsaved) instance to this ``Session``. This operation cascades the `save_or_update` method to @@ -808,12 +813,10 @@ class Session(object): specific ``Mapper`` used to handle this instance. """ - self._save_impl(object, entity_name=entity_name) - _object_mapper(object).cascade_callable('save-update', object, - lambda c, e:self._save_or_update_impl(c, e), - halt_on=lambda c:c in self) + self._save_impl(instance, entity_name=entity_name) + self._cascade_save_or_update(instance) - def update(self, object, entity_name=None): + def update(self, instance, entity_name=None): """Bring the given detached (saved) instance into this ``Session``. @@ -826,37 +829,37 @@ class Session(object): ``cascade="save-update"``. """ - self._update_impl(object, entity_name=entity_name) - _object_mapper(object).cascade_callable('save-update', object, - lambda c, e:self._save_or_update_impl(c, e), - halt_on=lambda c:c in self) + self._update_impl(instance, entity_name=entity_name) + self._cascade_save_or_update(instance) - def save_or_update(self, object, entity_name=None): - """Save or update the given object into this ``Session``. + def save_or_update(self, instance, entity_name=None): + """Save or update the given instance into this ``Session``. The presence of an `_instance_key` attribute on the instance determines whether to ``save()`` or ``update()`` the instance. """ - self._save_or_update_impl(object, entity_name=entity_name) - _object_mapper(object).cascade_callable('save-update', object, - lambda c, e:self._save_or_update_impl(c, e), - halt_on=lambda c:c in self) + self._save_or_update_impl(instance, entity_name=entity_name) + self._cascade_save_or_update(instance) + + def _cascade_save_or_update(self, instance): + for obj, mapper in _cascade_iterator('save-update', instance, halt_on=lambda c:c in self): + self._save_or_update_impl(obj, mapper.entity_name) - def delete(self, object): + def delete(self, instance): """Mark the given instance as deleted. The delete operation occurs upon ``flush()``. """ - self._delete_impl(object) - for c in list(_object_mapper(object).cascade_iterator('delete', object)): + self._delete_impl(instance) + for c, m in _cascade_iterator('delete', instance): self._delete_impl(c, ignore_transient=True) - def merge(self, object, entity_name=None, dont_load=False, _recursive=None): - """Copy the state of the given `object` onto the persistent - object with the same identifier. + def merge(self, instance, entity_name=None, dont_load=False, _recursive=None): + """Copy the state of the given `instance` onto the persistent + instance with the same identifier. If there is no persistent instance currently associated with the session, it will be loaded. Return the persistent @@ -871,20 +874,20 @@ class Session(object): if _recursive is None: _recursive = {} #TODO: this should be an IdentityDict if entity_name is not None: - mapper = _class_mapper(object.__class__, entity_name=entity_name) + mapper = _class_mapper(instance.__class__, entity_name=entity_name) else: - mapper = _object_mapper(object) - if object in _recursive: - return _recursive[object] + mapper = _object_mapper(instance) + if instance in _recursive: + return _recursive[instance] - key = getattr(object, '_instance_key', None) + key = getattr(instance, '_instance_key', None) if key is None: merged = attributes.new_instance(mapper.class_) else: if key in self.identity_map: merged = self.identity_map[key] elif dont_load: - if object._state.modified: + if instance._state.modified: raise exceptions.InvalidRequestError("merge() with dont_load=True option does not support objects marked as 'dirty'. flush() all changes on mapped instances before merging with dont_load=True.") merged = attributes.new_instance(mapper.class_) @@ -894,10 +897,10 @@ class Session(object): else: merged = self.get(mapper.class_, key[1]) if merged is None: - raise exceptions.AssertionError("Instance %s has an instance key but is not persisted" % mapperutil.instance_str(object)) - _recursive[object] = merged + raise exceptions.AssertionError("Instance %s has an instance key but is not persisted" % mapperutil.instance_str(instance)) + _recursive[instance] = merged for prop in mapper.iterate_properties: - prop.merge(self, object, merged, dont_load, _recursive) + prop.merge(self, instance, merged, dont_load, _recursive) if key is None: self.save(merged, entity_name=mapper.entity_name) elif dont_load: @@ -968,96 +971,96 @@ class Session(object): return mapper.identity_key_from_instance(instance) identity_key = classmethod(identity_key) - def object_session(cls, obj): + def object_session(cls, instance): """return the ``Session`` to which the given object belongs.""" - return object_session(obj) + return object_session(instance) object_session = classmethod(object_session) - def _save_impl(self, obj, **kwargs): - if hasattr(obj, '_instance_key'): - raise exceptions.InvalidRequestError("Instance '%s' is already persistent" % mapperutil.instance_str(obj)) + def _save_impl(self, instance, **kwargs): + if hasattr(instance, '_instance_key'): + raise exceptions.InvalidRequestError("Instance '%s' is already persistent" % mapperutil.instance_str(instance)) else: # TODO: consolidate the steps here - attributes.manage(obj) - obj._entity_name = kwargs.get('entity_name', None) - self._attach(obj) - self.uow.register_new(obj) + attributes.manage(instance) + instance._entity_name = kwargs.get('entity_name', None) + self._attach(instance) + self.uow.register_new(instance) - def _update_impl(self, obj, **kwargs): - if obj in self and obj not in self.deleted: + def _update_impl(self, instance, **kwargs): + if instance in self and instance not in self.deleted: return - if not hasattr(obj, '_instance_key'): - raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(obj)) - elif self.identity_map.get(obj._instance_key, obj) is not obj: + if not hasattr(instance, '_instance_key'): + raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(instance)) + elif self.identity_map.get(instance._instance_key, instance) is not instance: raise exceptions.InvalidRequestError("Could not update instance '%s', identity key %s; a different instance with the same identity key already exists in this session." % (mapperutil.instance_str(obj), obj._instance_key)) - self._attach(obj) + self._attach(instance) - def _save_or_update_impl(self, object, entity_name=None): - key = getattr(object, '_instance_key', None) + def _save_or_update_impl(self, instance, entity_name=None): + key = getattr(instance, '_instance_key', None) if key is None: - self._save_impl(object, entity_name=entity_name) + self._save_impl(instance, entity_name=entity_name) else: - self._update_impl(object, entity_name=entity_name) + self._update_impl(instance, entity_name=entity_name) - def _delete_impl(self, obj, ignore_transient=False): - if obj in self and obj in self.deleted: + def _delete_impl(self, instance, ignore_transient=False): + if instance in self and instance in self.deleted: return - if not hasattr(obj, '_instance_key'): + if not hasattr(instance, '_instance_key'): if ignore_transient: return else: - raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(obj)) - if self.identity_map.get(obj._instance_key, obj) is not obj: - raise exceptions.InvalidRequestError("Instance '%s' is with key %s already persisted with a different identity" % (mapperutil.instance_str(obj), obj._instance_key)) - self._attach(obj) - self.uow.register_deleted(obj) - - def _register_persistent(self, obj): - obj._sa_session_id = self.hash_key - self.identity_map[obj._instance_key] = obj - obj._state.commit_all() - - def _attach(self, obj): - old_id = getattr(obj, '_sa_session_id', None) + raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(instance)) + if self.identity_map.get(instance._instance_key, instance) is not instance: + raise exceptions.InvalidRequestError("Instance '%s' is with key %s already persisted with a different identity" % (mapperutil.instance_str(instance), instance._instance_key)) + self._attach(instance) + self.uow.register_deleted(instance) + + def _register_persistent(self, instance): + instance._sa_session_id = self.hash_key + self.identity_map[instance._instance_key] = instance + instance._state.commit_all() + + def _attach(self, instance): + old_id = getattr(instance, '_sa_session_id', None) if old_id != self.hash_key: - if old_id is not None and old_id in _sessions and obj in _sessions[old_id]: + if old_id is not None and old_id in _sessions and instance in _sessions[old_id]: raise exceptions.InvalidRequestError("Object '%s' is already attached " "to session '%s' (this is '%s')" % - (mapperutil.instance_str(obj), old_id, id(self))) + (mapperutil.instance_str(instance), old_id, id(self))) - key = getattr(obj, '_instance_key', None) + key = getattr(instance, '_instance_key', None) if key is not None: - self.identity_map[key] = obj - obj._sa_session_id = self.hash_key + self.identity_map[key] = instance + instance._sa_session_id = self.hash_key - def _unattach(self, obj): - if obj._sa_session_id == self.hash_key: - del obj._sa_session_id + def _unattach(self, instance): + if instance._sa_session_id == self.hash_key: + del instance._sa_session_id - def _validate_persistent(self, obj): - """Validate that the given object is persistent within this + def _validate_persistent(self, instance): + """Validate that the given instance is persistent within this ``Session``. """ - return obj in self + return instance in self - def __contains__(self, obj): - """return True if the given object is associated with this session. + def __contains__(self, instance): + """return True if the given instance is associated with this session. The instance may be pending or persistent within the Session for a result of True. """ - return obj in self.uow.new or (hasattr(obj, '_instance_key') and self.identity_map.get(obj._instance_key) is obj) + return instance in self.uow.new or (hasattr(instance, '_instance_key') and self.identity_map.get(instance._instance_key) is instance) def __iter__(self): - """return an iterator of all objects which are pending or persistent within this Session.""" + """return an iterator of all instances which are pending or persistent within this Session.""" return iter(list(self.uow.new) + self.uow.identity_map.values()) - def is_modified(self, obj, include_collections=True, passive=False): - """return True if the given object has modified attributes. + def is_modified(self, instance, include_collections=True, passive=False): + """return True if the given instance has modified attributes. This method retrieves a history instance for each instrumented attribute on the instance and performs a comparison of the current value to its @@ -1073,15 +1076,15 @@ class Session(object): not be loaded in the course of performing this test. """ - for attr in attributes.managed_attributes(obj.__class__): + for attr in attributes.managed_attributes(instance.__class__): if not include_collections and hasattr(attr.impl, 'get_collection'): continue - if attr.get_history(obj).is_modified(): + if attr.get_history(instance).is_modified(): return True return False dirty = property(lambda s:s.uow.locate_dirty(), - doc="""A ``Set`` of all objects marked as 'dirty' within this ``Session``. + doc="""A ``Set`` of all instances marked as 'dirty' within this ``Session``. Note that the 'dirty' state here is 'optimistic'; most attribute-setting or collection modification operations will mark an instance as 'dirty' and place it in this set, @@ -1095,12 +1098,12 @@ class Session(object): """) deleted = property(lambda s:s.uow.deleted, - doc="A ``Set`` of all objects marked as 'deleted' within this ``Session``") + doc="A ``Set`` of all instances marked as 'deleted' within this ``Session``") new = property(lambda s:s.uow.new, - doc="A ``Set`` of all objects marked as 'new' within this ``Session``.") + doc="A ``Set`` of all instances marked as 'new' within this ``Session``.") -def expire_instance(obj, attribute_names): +def expire_instance(instance, attribute_names): """standalone expire instance function. installs a callable with the given instance's _state @@ -1110,29 +1113,30 @@ def expire_instance(obj, attribute_names): If the list is None or blank, the entire instance is expired. """ - if obj._state.trigger is None: + if instance._state.trigger is None: def load_attributes(instance, attribute_names): if object_session(instance).query(instance.__class__)._get(instance._instance_key, refresh_instance=instance, only_load_props=attribute_names) is None: raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance)) - obj._state.trigger = load_attributes + instance._state.trigger = load_attributes - obj._state.expire_attributes(attribute_names) + instance._state.expire_attributes(attribute_names) register_attribute = unitofwork.register_attribute -# this dictionary maps the hash key of a Session to the Session itself, and -# acts as a Registry with which to locate Sessions. this is to enable -# object instances to be associated with Sessions without having to attach the -# actual Session object directly to the object instance. _sessions = weakref.WeakValueDictionary() -def object_session(obj): - """Return the ``Session`` to which the given object is bound, or ``None`` if none.""" +def _cascade_iterator(cascade, instance, **kwargs): + mapper = _object_mapper(instance) + for (o, m) in mapper.cascade_iterator(cascade, instance, **kwargs): + yield o, m + +def object_session(instance): + """Return the ``Session`` to which the given instance is bound, or ``None`` if none.""" - hashkey = getattr(obj, '_sa_session_id', None) + hashkey = getattr(instance, '_sa_session_id', None) if hashkey is not None: sess = _sessions.get(hashkey) - if obj in sess: + if instance in sess: return sess return None diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 096a42bb71..3c647ac604 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -109,7 +109,7 @@ class ColumnLoader(LoaderStrategy): def create_statement(instance): params = {} for c in param_names: - params[c.name] = mapper.get_attr_by_column(instance, c) + params[c.name] = mapper._get_attr_by_column(instance, c) return (statement, params) def new_execute(instance, row, isnew, **flags): @@ -301,7 +301,7 @@ class LazyLoader(AbstractRelationLoader): def visit_bindparam(s, bindparam): mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent if bindparam.key in bind_to_col: - bindparam.value = mapper.get_attr_by_column(instance, bind_to_col[bindparam.key]) + bindparam.value = mapper._get_attr_by_column(instance, bind_to_col[bindparam.key]) return Visitor().traverse(criterion, clone=True) def setup_loader(self, instance, options=None, path=None): @@ -338,7 +338,7 @@ class LazyLoader(AbstractRelationLoader): if self.use_get: params = {} for col, bind in self.lazybinds.iteritems(): - params[bind.key] = self.parent.get_attr_by_column(instance, col) + params[bind.key] = self.parent._get_attr_by_column(instance, col) ident = [] nonnulls = False for primary_key in self.select_mapper.primary_key: diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index 9575aa958f..8132c7e4a5 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -115,10 +115,12 @@ class SyncRule(object): #print "SyncRule", source_mapper, source_column, dest_column, dest_mapper def dest_primary_key(self): + # late-evaluating boolean since some syncs are created + # before the mapper has assembled pks try: return self._dest_primary_key except AttributeError: - self._dest_primary_key = self.dest_mapper is not None and self.dest_column in self.dest_mapper.pks_by_table[self.dest_column.table] and not self.dest_mapper.allow_null_pks + self._dest_primary_key = self.dest_mapper is not None and self.dest_column in self.dest_mapper._pks_by_table[self.dest_column.table] and not self.dest_mapper.allow_null_pks return self._dest_primary_key def execute(self, source, dest, obj, child, clearkeys): @@ -131,7 +133,7 @@ class SyncRule(object): value = None clearkeys = True else: - value = self.source_mapper.get_attr_by_column(source, self.source_column) + value = self.source_mapper._get_attr_by_column(source, self.source_column) if isinstance(dest, dict): dest[self.dest_column.key] = value else: @@ -140,7 +142,7 @@ class SyncRule(object): if logging.is_debug_enabled(self.logger): self.logger.debug("execute() instances: %s(%s)->%s(%s) ('%s')" % (mapperutil.instance_str(source), str(self.source_column), mapperutil.instance_str(dest), str(self.dest_column), value)) - self.dest_mapper.set_attr_by_column(dest, self.dest_column, value) + self.dest_mapper._set_attr_by_column(dest, self.dest_column, value) SyncRule.logger = logging.class_logger(SyncRule) diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index d3e89d57e6..1cf0cb1b05 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -18,8 +18,9 @@ def sort_tables(tables, reverse=False): vis.traverse(table) sequence = topological.QueueDependencySorter( tuples, tables).sort(create_tree=False) if reverse: - sequence.reverse() - return sequence + return util.reversed(sequence) + else: + return sequence def find_tables(clause, check_columns=False, include_aliases=False): tables = [] diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/basic.py index 0301d01c99..5affaa238f 100644 --- a/test/orm/inheritance/basic.py +++ b/test/orm/inheritance/basic.py @@ -489,7 +489,7 @@ class DistinctPKTest(ORMTest): self._do_test(True) assert False except RuntimeWarning, e: - assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'. Use explicit properties to give each column its own mapped attribute name." + assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'. Use explicit properties to give each column its own mapped attribute name.", str(e) def test_explicit_pk(self): person_mapper = mapper(Person, person_table) diff --git a/test/orm/manytomany.py b/test/orm/manytomany.py index 5608ae67de..a94e9bbc44 100644 --- a/test/orm/manytomany.py +++ b/test/orm/manytomany.py @@ -79,7 +79,11 @@ class M2MTest(ORMTest): compile_mappers() assert False except exceptions.ArgumentError, e: - assert str(e) == "Error creating backref 'transitions' on relation 'Transition.places (Place)': property of that name exists on mapper 'Mapper|Place|place'" + assert str(e) in [ + "Error creating backref 'transitions' on relation 'Transition.places (Place)': property of that name exists on mapper 'Mapper|Place|place'", + "Error creating backref 'places' on relation 'Place.transitions (Transition)': property of that name exists on mapper 'Mapper|Transition|transition'" + ] + def testcircular(self): """tests a many-to-many relationship from a table to itself.""" diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 9cd07b8fee..f0d5536302 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -269,8 +269,8 @@ class MapperTest(MapperSuperTest): class A(object):pass m = mapper(A, account_ids_table.join(account_stuff_table)) m.compile() - assert m._has_pks(account_ids_table) - assert not m._has_pks(account_stuff_table) + assert account_ids_table in m._pks_by_table + assert account_stuff_table not in m._pks_by_table metadata.create_all(testbase.db) try: sess = create_session(bind=testbase.db) -- 2.47.2