From 2bef6699d35b80bf1e329878f8f6a46134b9dc3d Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 19 Oct 2006 07:02:04 +0000 Subject: [PATCH] progress on [ticket:329] --- CHANGES | 2 +- lib/sqlalchemy/orm/__init__.py | 5 +- lib/sqlalchemy/orm/interfaces.py | 47 ++- lib/sqlalchemy/orm/mapper.py | 124 +++---- lib/sqlalchemy/orm/query.py | 82 ++++- lib/sqlalchemy/orm/session.py | 89 ++--- lib/sqlalchemy/orm/strategies.py | 28 +- lib/sqlalchemy/orm/unitofwork.py | 413 +++++++++++------------- lib/sqlalchemy/sql_util.py | 7 +- lib/sqlalchemy/{orm => }/topological.py | 2 +- test/orm/mapper.py | 14 + 11 files changed, 434 insertions(+), 379 deletions(-) rename lib/sqlalchemy/{orm => }/topological.py (98%) diff --git a/CHANGES b/CHANGES index 9975170bd2..23e3a78d83 100644 --- a/CHANGES +++ b/CHANGES @@ -36,7 +36,7 @@ methods, methods that are no longer needed. slightly more constrained useage, greater emphasis on explicitness - the "primary_key" attribute of Table and other selectables becomes - a setlike ColumnCollection object; is no longer ordered or numerically + a setlike ColumnCollection object; is ordered but not numerically indexed. a comparison clause between two pks that are derived from the same underlying tables (i.e. such as two Alias objects) can be generated via table1.primary_key==table2.primary_key diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 2d9a4e8451..cea363116b 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -19,7 +19,7 @@ from session import Session as create_session __all__ = ['relation', 'backref', 'eagerload', 'lazyload', 'noload', 'deferred', 'defer', 'undefer', 'mapper', 'clear_mappers', 'clear_mapper', 'sql', 'class_mapper', 'object_mapper', 'MapperExtension', 'Query', - 'cascade_mappers', 'polymorphic_union', 'create_session', 'synonym', 'EXT_PASS' + 'cascade_mappers', 'polymorphic_union', 'create_session', 'synonym', 'contains_eager', 'EXT_PASS' ] def relation(*args, **kwargs): @@ -75,6 +75,9 @@ def noload(name): into a non-load.""" return strategies.EagerLazyOption(name, lazy=None) +def contains_eager(key, decorator=None): + return strategies.RowDecorateOption(key, decorator=decorator) + def defer(name): """returns a MapperOption that will convert the column property of the given name into a deferred load. Used with mapper.options()""" diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 9a6b404a02..872164d325 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -72,6 +72,7 @@ class StrategizedProperty(MapperProperty): self._all_strategies[cls] = strategy return strategy def setup(self, querycontext, **kwargs): + print "SP SETUP, KEY", self.key, " STRAT IS ", self._get_context_strategy(querycontext) self._get_context_strategy(querycontext).setup_query(querycontext, **kwargs) def execute(self, selectcontext, instance, row, identitykey, isnew): self._get_context_strategy(selectcontext).process_row(selectcontext, instance, row, identitykey, isnew) @@ -92,32 +93,50 @@ class OperationContext(object): self.attributes = {} self.recursion_stack = util.Set() for opt in options: - opt.process_context(self) + self.accept_option(opt) + def accept_option(self, opt): + pass class MapperOption(object): """describes a modification to an OperationContext.""" - def process_context(self, context): + def process_query_context(self, context): pass - -class StrategizedOption(MapperOption): - """a MapperOption that affects which LoaderStrategy will be used for an operation - by a StrategizedProperty.""" + def process_selection_context(self, context): + pass + +class PropertyOption(MapperOption): + """a MapperOption that is applied to a property off the mapper + or one of its child mappers, identified by a dot-separated key.""" def __init__(self, key): self.key = key - def get_strategy_class(self): - raise NotImplementedError() - def process_context(self, context): + def process_query_property(self, context, property): + pass + def process_selection_property(self, context, property): + pass + def process_query_context(self, context): + self.process_query_property(context, self._get_property(context)) + def process_selection_context(self, context): + self.process_selection_property(context, self._get_property(context)) + def _get_property(self, context): try: - key = self.__key + prop = self.__prop except AttributeError: mapper = context.mapper for token in self.key.split('.'): prop = mapper.props[token] mapper = getattr(prop, 'mapper', None) - self.__key = (LoaderStrategy, prop) - key = self.__key - context.attributes[key] = self.get_strategy_class() - + self.__prop = prop + return prop + +class StrategizedOption(PropertyOption): + """a MapperOption that affects which LoaderStrategy will be used for an operation + by a StrategizedProperty.""" + def process_query_property(self, context, property): + print "HI " + self.key + " " + property.key + context.attributes[(LoaderStrategy, property)] = self.get_strategy_class() + def get_strategy_class(self): + raise NotImplementedError() + class LoaderStrategy(object): """describes the loading behavior of a StrategizedProperty object. The LoaderStrategy diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index e1491a4a84..87c276368e 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -13,7 +13,7 @@ import query as querylib import session as sessionlib import weakref -__all__ = ['Mapper', 'MapperExtension', 'class_mapper', 'object_mapper', 'EXT_PASS', 'SelectionContext'] +__all__ = ['Mapper', 'MapperExtension', 'class_mapper', 'object_mapper', 'EXT_PASS'] # a dictionary mapping classes to their primary mappers mapper_registry = weakref.WeakKeyDictionary() @@ -686,20 +686,22 @@ class Mapper(object): return mapper_registry[self.class_key] def is_assigned(self, instance): - """returns True if this mapper handles the given instance. this is dependent - not only on class assignment but the optional "entity_name" parameter as well.""" + """return True if this mapper handles the given instance. + + this is dependent not only on class assignment but the optional "entity_name" parameter as well.""" return instance.__class__ is self.class_ and getattr(instance, '_entity_name', None) == self.entity_name def _assign_entity_name(self, instance): - """assigns this Mapper's entity name to the given instance. subsequent Mapper lookups for this - instance will return the primary mapper corresponding to this Mapper's class and entity name.""" + """assign this Mapper's entity name to the given instance. + + subsequent Mapper lookups for this instance will return the primary + mapper corresponding to this Mapper's class and entity name.""" instance._entity_name = self.entity_name def get_session(self): - """returns the contextual session provided by the mapper extension chain + """return the contextual session provided by the mapper extension chain, if any. - raises InvalidRequestError if a session cannot be retrieved from the - extension chain + raises InvalidRequestError if a session cannot be retrieved from the extension chain """ self.compile() s = self.extension.get_session() @@ -708,52 +710,50 @@ class Mapper(object): return s def has_eager(self): - """returns True if one of the properties attached to this Mapper is eager loading""" + """return True if one of the properties attached to this Mapper is eager loading""" return getattr(self, '_has_eager', False) - def instances(self, cursor, session, *mappers, **kwargs): - """given a cursor (ResultProxy) from an SQLEngine, returns a list of object instances - corresponding to the rows in the cursor.""" - self.__log_debug("instances()") - self.compile() + """return a list of mapped instances corresponding to the rows in a given ResultProxy.""" + return querylib.Query(self, session).instances(cursor, *mappers, **kwargs) + + def identity_key_from_row(self, row): + """return an identity-map key for use in storing/retrieving an item from the identity map. + + row - a sqlalchemy.dbengine.RowProxy instance or other map corresponding result-set + column names to their values within a row. + """ + return (self.class_, tuple([row[column] for column in self.pks_by_table[self.mapped_table]]), self.entity_name) - context = SelectionContext(self, session, **kwargs) + def identity_key_from_primary_key(self, primary_key): + """return an identity-map key for use in storing/retrieving an item from an identity map. - result = util.UniqueAppender([]) - if mappers: - otherresults = [] - for m in mappers: - otherresults.append(util.UniqueAppender([])) - - for row in cursor.fetchall(): - self._instance(context, row, result) - i = 0 - for m in mappers: - m._instance(context, row, otherresults[i]) - i+=1 - - # store new stuff in the identity map - for value in context.identity_map.values(): - session._register_persistent(value) - - if mappers: - return [result.data] + [o.data for o in otherresults] - else: - return result.data + primary_key - a list of values indicating the identifier. + """ + return (self.class_, tuple(util.to_list(primary_key)), self.entity_name) + + def identity_key_from_instance(self, instance): + """return the identity key for the given instance, based on its primary key attributes. - def identity_key(self, primary_key): - """returns the instance key for the given identity value. this is a global tracking object used by the Session, and is usually available off a mapped object as instance._instance_key.""" - return sessionlib.get_id_key(util.to_list(primary_key), self.class_, self.entity_name) + this value is typically also found on the instance itself under the attribute name '_instance_key'. + """ + return self.identity_key_from_primary_key(self.primary_key_from_instance(instance)) + + def primary_key_from_instance(self, instance): + """return the list of primary key values for the given instance.""" + return [self._getattrbycolumn(instance, column) for column in self.pks_by_table[self.mapped_table]] def instance_key(self, instance): - """returns the instance key for the given instance. this is a global tracking object used by the Session, and is usually available off a mapped object as instance._instance_key.""" - return self.identity_key(self.identity(instance)) + """deprecated. a synonym for identity_key_from_instance.""" + return self.identity_key_from_instance(instance) + + def identity_key(self, primary_key): + """deprecated. a synonym for identity_key_from_primary_key.""" + return self.identity_key_from_primary_key(primary_key) def identity(self, instance): - """returns the identity (list of primary key values) for the given instance. The list of values can be fed directly into the get() method as mapper.get(*key).""" - return [self._getattrbycolumn(instance, column) for column in self.pks_by_table[self.mapped_table]] - + """deprecated. a synoynm for primary_key_from_instance.""" + return self.primary_key_from_instance(instance) def _getpropbycolumn(self, column, raiseerror=True): try: @@ -1090,8 +1090,6 @@ class Mapper(object): for prop in self.__props.values(): prop.cascade_callable(type, object, callable_, recursive) - def _row_identity_key(self, row): - return sessionlib.get_row_key(row, self.class_, self.pks_by_table[self.mapped_table], self.entity_name) def get_select_mapper(self): """return the mapper used for issuing selects. @@ -1117,7 +1115,7 @@ class Mapper(object): # been exposed to being modified by the application. populate_existing = context.populate_existing or self.always_refresh - identitykey = self._row_identity_key(row) + identitykey = self.identity_key_from_row(row) if context.session.has_key(identitykey): instance = context.session._get(identitykey) self.__log_debug("_instance(): using existing instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey))) @@ -1202,36 +1200,6 @@ class Mapper(object): Mapper.logger = logging.class_logger(Mapper) -class SelectionContext(OperationContext): - """created within the mapper.instances() method to store and share - state among all the Mappers and MapperProperty objects used in a load operation. - - SelectionContext contains these attributes: - - mapper - the Mapper which originated the instances() call. - - session - the Session that is relevant to the instances call. - - identity_map - a dictionary which stores newly created instances that have - not yet been added as persistent to the Session. - - attributes - a dictionary to store arbitrary data; eager loaders use it to - store additional result lists - - populate_existing - indicates if its OK to overwrite the attributes of instances - that were already in the Session - - version_check - indicates if mappers that have version_id columns should verify - that instances existing already within the Session should have this attribute compared - to the freshly loaded value - - """ - def __init__(self, mapper, session, **kwargs): - self.populate_existing = kwargs.pop('populate_existing', False) - self.version_check = kwargs.pop('version_check', False) - self.session = session - self.identity_map = {} - super(SelectionContext, self).__init__(mapper, kwargs.pop('with_options', []), **kwargs) class MapperExtension(object): """base implementation for an object that provides overriding behavior to various @@ -1378,6 +1346,8 @@ def has_identity(object): def has_mapper(object): """returns True if the given object has a mapper association""" return hasattr(object, '_entity_name') + + def object_mapper(object, raiseerror=True): """given an object, returns the primary Mapper associated with the object instance""" diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index b436adb014..a7021d4722 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -5,11 +5,13 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php import session as sessionlib -from sqlalchemy import sql, util, exceptions, sql_util +from sqlalchemy import sql, util, exceptions, sql_util, logging import mapper from interfaces import OperationContext +__all__ = ['Query', 'QueryContext', 'SelectionContext'] + class Query(object): """encapsulates the object-fetching operations provided by Mappers.""" def __init__(self, class_or_mapper, session=None, entity_name=None, lockmode=None, with_options=None, **kwargs): @@ -244,7 +246,7 @@ class Query(object): def select_text(self, text, **params): t = sql.text(text) - return self.instances(t, params=params) + return self.execute(t, params=params) def options(self, *args, **kwargs): """returns a new Query object using the given MapperOptions.""" @@ -268,12 +270,43 @@ class Query(object): else: raise AttributeError(key) - def instances(self, clauseelement, params=None, *args, **kwargs): + def execute(self, clauseelement, params=None, *args, **kwargs): result = self.session.execute(self.mapper, clauseelement, params=params) try: return self.mapper.instances(result, self.session, with_options=self.with_options, **kwargs) finally: result.close() + + def instances(self, cursor, *mappers, **kwargs): + """return a list of mapped instances corresponding to the rows in a given ResultProxy.""" + self.__log_debug("instances()") + + session = self.session + + context = SelectionContext(self.mapper, session, **kwargs) + + result = util.UniqueAppender([]) + if mappers: + otherresults = [] + for m in mappers: + otherresults.append(util.UniqueAppender([])) + + for row in cursor.fetchall(): + self.mapper._instance(context, row, result) + i = 0 + for m in mappers: + m._instance(context, row, otherresults[i]) + i+=1 + + # store new stuff in the identity map + for value in context.identity_map.values(): + session._register_persistent(value) + + if mappers: + return [result.data] + [o.data for o in otherresults] + else: + return result.data + def _get(self, key, ident=None, reload=False, lockmode=None): lockmode = lockmode or self.lockmode @@ -308,7 +341,7 @@ class Query(object): statement.use_labels = True if params is None: params = {} - return self.instances(statement, params=params, **kwargs) + return self.execute(statement, params=params, **kwargs) def _should_nest(self, querycontext): """return True if the given statement options indicate that we should "nest" the @@ -397,6 +430,11 @@ class Query(object): return statement + def __log_debug(self, msg): + self.logger.debug(msg) + +Query.logger = logging.class_logger(Query) + class QueryContext(OperationContext): """created within the Query.compile() method to store and share state among all the Mappers and MapperProperty objects used in a query construction.""" @@ -412,4 +450,40 @@ class QueryContext(OperationContext): super(QueryContext, self).__init__(query.mapper, query.with_options, **kwargs) def select_args(self): return {'limit':self.limit, 'offset':self.offset, 'distinct':self.distinct} + def accept_option(self, opt): + opt.process_query_context(self) + + +class SelectionContext(OperationContext): + """created within the query.instances() method to store and share + state among all the Mappers and MapperProperty objects used in a load operation. + + SelectionContext contains these attributes: + + mapper - the Mapper which originated the instances() call. + + session - the Session that is relevant to the instances call. + + identity_map - a dictionary which stores newly created instances that have + not yet been added as persistent to the Session. + + attributes - a dictionary to store arbitrary data; eager loaders use it to + store additional result lists + + populate_existing - indicates if its OK to overwrite the attributes of instances + that were already in the Session + + version_check - indicates if mappers that have version_id columns should verify + that instances existing already within the Session should have this attribute compared + to the freshly loaded value + + """ + def __init__(self, mapper, session, **kwargs): + self.populate_existing = kwargs.pop('populate_existing', False) + self.version_check = kwargs.pop('version_check', False) + self.session = session + self.identity_map = {} + super(SelectionContext, self).__init__(mapper, kwargs.pop('with_options', []), **kwargs) + def accept_option(self, opt): + opt.process_selection_context(self) \ No newline at end of file diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 3ec3044b8d..ca64faebe8 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -140,15 +140,17 @@ class Session(object): self.uow.echo = echo def mapper(self, class_, entity_name=None): - """given an Class, returns the primary Mapper responsible for persisting it""" + """given an Class, return the primary Mapper responsible for persisting it""" return class_mapper(class_, entity_name = entity_name) def bind_mapper(self, mapper, bindto): - """binds the given Mapper to the given Engine or Connection. All subsequent operations involving this - Mapper will use the given bindto.""" + """bind the given Mapper to the given Engine or Connection. + + All subsequent operations involving this Mapper will use the given bindto.""" self.binds[mapper] = bindto def bind_table(self, table, bindto): - """binds the given Table to the given Engine or Connection. All subsequent operations involving this - Table will use the given bindto.""" + """bind the given Table to the given Engine or Connection. + + All subsequent operations involving this Table will use the given bindto.""" self.binds[table] = bindto def get_bind(self, mapper): """return the Engine or Connection which is used to execute statements on behalf of the given Mapper. @@ -198,36 +200,6 @@ class Session(object): sql = property(_sql) - - def get_id_key(ident, class_, entity_name=None): - """return an identity-map key for use in storing/retrieving an item from the identity map. - - ident - a tuple of primary key values corresponding to the object to be stored. these - values should be in the same order as the primary keys of the table - - class_ - a reference to the object's class - - entity_name - optional string name to further qualify the class - """ - return (class_, tuple(ident), entity_name) - get_id_key = staticmethod(get_id_key) - - def get_row_key(row, class_, primary_key, entity_name=None): - """return an identity-map key for use in storing/retrieving an item from the identity map. - - row - a sqlalchemy.dbengine.RowProxy instance or other map corresponding result-set - column names to their values within a row. - - class_ - a reference to the object's class - - primary_key - a list of column objects that will target the primary key values - in the given row. - - entity_name - optional string name to further qualify the class - """ - return (class_, tuple([row[column] for column in primary_key]), entity_name) - get_row_key = staticmethod(get_row_key) - def flush(self, objects=None): """flush all the object modifications present in this session to the database. @@ -265,8 +237,12 @@ class Session(object): raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % repr(obj)) def expire(self, obj): - """invalidate the data in the given object and sets them to refresh themselves - the next time they are requested.""" + """mark the given object as expired. + + this will add an instrumentation to all mapped attributes on the instance such that when + an attribute is next accessed, the session will reload all attributes on the instance + from the database. + """ self._validate_persistent(obj) def exp(): if self.query(obj.__class__)._get(obj._instance_key, reload=True) is None: @@ -274,6 +250,7 @@ class Session(object): attribute_manager.trigger_history(obj, exp) def is_expired(self, obj, unexpire=False): + """return True if the given object has been marked as expired.""" ret = attribute_manager.has_trigger(obj) if ret and unexpire: attribute_manager.untrigger_history(obj) @@ -290,8 +267,10 @@ class Session(object): def save(self, object, entity_name=None): """ - Adds a transient (unsaved) instance to this Session. This operation cascades the "save_or_update" - method to associated instances if the relation is mapped with cascade="save-update". + Add a transient (unsaved) instance to this Session. + + This operation cascades the "save_or_update" method to associated instances if the + relation is mapped with cascade="save-update". The 'entity_name' keyword argument will further qualify the specific Mapper used to handle this instance. @@ -300,15 +279,21 @@ class Session(object): object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e)) def update(self, object, entity_name=None): - """Brings the given detached (saved) instance into this Session. - If there is a persistent instance with the same identifier (i.e. a saved instance already associated with this - Session), an exception is thrown. + """Bring the given detached (saved) instance into this Session. + + If there is a persistent instance with the same identifier already associated + with this Session, an exception is thrown. + This operation cascades the "save_or_update" method to associated instances if the relation is mapped with cascade="save-update".""" self._update_impl(object, entity_name=entity_name) object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e)) def save_or_update(self, object, entity_name=None): + """save or update the given object into this Session. + + The presence of an '_instance_key' attribute on the instance determines whether to + save() or update() the instance.""" self._save_or_update_impl(object, entity_name=entity_name) object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e)) @@ -320,11 +305,16 @@ class Session(object): self._update_impl(object, entity_name=entity_name) def delete(self, object, entity_name=None): - #self.uow.register_deleted(object) + """mark the given instance as deleted. + + the delete operation occurs upon flush().""" for c in [object] + list(object_mapper(object).cascade_iterator('delete', object)): self.uow.register_deleted(c) def merge(self, object, entity_name=None): + """merge the object into a newly loaded or existing instance from this Session. + + note: this method is currently not completely implemented.""" instance = None for obj in [object] + list(object_mapper(object).cascade_iterator('merge', object)): key = getattr(obj, '_instance_key', None) @@ -430,12 +420,6 @@ class Session(object): """deprecated; a synynom for merge()""" return self.merge(*args, **kwargs) -def get_id_key(ident, class_, entity_name=None): - return Session.get_id_key(ident, class_, entity_name) - -def get_row_key(row, class_, primary_key, entity_name=None): - return Session.get_row_key(row, class_, primary_key, entity_name) - def object_mapper(obj): return sqlalchemy.orm.object_mapper(obj) @@ -453,6 +437,7 @@ attribute_manager = unitofwork.attribute_manager _sessions = weakref.WeakValueDictionary() def object_session(obj): + """return the Session to which the given object is bound, or None if none.""" hashkey = getattr(obj, '_sa_session_id', None) if hashkey is not None: return _sessions.get(hashkey) @@ -460,9 +445,3 @@ def object_session(obj): unitofwork.object_session = object_session - -def get_session(obj=None): - """deprecated""" - if obj is not None: - return object_session(obj) - raise exceptions.InvalidRequestError("get_session() is deprecated, and does not return the thread-local session anymore. Use the SessionContext.mapper_extension or import sqlalchemy.mod.threadlocal to establish a default thread-local context.") diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 6bbeb65f1c..88d7f6c52e 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -458,10 +458,24 @@ class EagerLoader(AbstractRelationLoader): return try: - clauses = self.clauses_by_lead_mapper[selectcontext.mapper] - decorated_row = clauses._decorate_row(row) + # decorate the row according to the stored AliasedClauses for this eager load, + # or look for a user-defined decorator in the SelectContext (which was set up by the contains_eager() option) + if selectcontext.attributes.has_key((EagerLoader, self)): + # custom row decoration function, placed in the selectcontext by the + # contains_eager() mapper option + decorator = selectcontext.attributes[(EagerLoader, self)] + if decorator is None: + decorated_row = row + else: + decorated_row = decorator(row) + print "OK! ROW IS", decorated_row + else: + # AliasedClauses, keyed to the lead mapper used in the query + clauses = self.clauses_by_lead_mapper[selectcontext.mapper] + decorated_row = clauses._decorate_row(row) + print "OK! DECORATED ROW IS", decorated_row # check for identity key - identity_key = self.mapper._row_identity_key(decorated_row) + identity_key = self.mapper.identity_key_from_row(decorated_row) except KeyError: # else degrade to a lazy loader self.logger.debug("degrade to lazy loader on %s" % mapperutil.attribute_str(instance, self.key)) @@ -513,5 +527,11 @@ class EagerLazyOption(StrategizedOption): elif self.lazy is None: return NoLoader - +class RowDecorateOption(PropertyOption): + def __init__(self, key, decorator=None): + super(RowDecorateOption, self).__init__(key) + self.decorator = decorator + def process_selection_property(self, context, property): + context.attributes[(EagerLoader, property)] = self.decorator + diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 5c2e21c5fa..c4fd92e361 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -14,12 +14,11 @@ an "identity map" pattern. The Unit of Work then maintains lists of objects tha dirty, or deleted and provides the capability to flush all those changes at once. """ -from sqlalchemy import attributes, util, logging +from sqlalchemy import attributes, util, logging, topological import sqlalchemy from sqlalchemy.exceptions import * import StringIO import weakref -import topological import sets class UOWEventHandler(attributes.AttributeExtension): @@ -203,8 +202,6 @@ class UOWTransaction(object): self.mappers = util.Set() self.dependencies = {} self.tasks = {} - self.__modified = False - self.__is_executing = False self.logger = logging.instance_logger(self) self.echo = uow.echo @@ -229,8 +226,7 @@ class UOWTransaction(object): self.mappers.add(mapper) task = self.get_task_by_mapper(mapper) if postupdate: - mod = task.append_postupdate(obj, post_update_cols) - if mod: self._mark_modified() + task.append_postupdate(obj, post_update_cols) return # for a cyclical task, things need to be sorted out already, @@ -239,8 +235,7 @@ class UOWTransaction(object): if task.circular: return - mod = task.append(obj, listonly, isdelete=isdelete, **kwargs) - if mod: self._mark_modified() + task.append(obj, listonly, isdelete=isdelete, **kwargs) def unregister_object(self, obj): #print "UNREGISTER", obj @@ -248,12 +243,6 @@ class UOWTransaction(object): task = self.get_task_by_mapper(mapper) if obj in task.objects: task.delete(obj) - self._mark_modified() - - def _mark_modified(self): - #if self.__is_executing: - # raise "test assertion failed" - self.__modified = True def is_deleted(self, obj): @@ -287,7 +276,6 @@ class UOWTransaction(object): dependency = dependency.primary_mapper().base_mapper() self.dependencies[(mapper, dependency)] = True - self._mark_modified() def register_processor(self, mapper, processor, mapperfrom): """called by mapper.PropertyLoader to register itself as a "processor", which @@ -307,7 +295,6 @@ class UOWTransaction(object): targettask = self.get_task_by_mapper(mapperfrom) up = UOWDependencyProcessor(processor, targettask) task.dependencies.add(up) - self._mark_modified() def execute(self): # insure that we have a UOWTask for every mapper that will be involved @@ -328,14 +315,7 @@ class UOWTransaction(object): if not ret: break - # flip the execution flag on. in some test cases - # we like to check this flag against any new objects being added, since everything - # should be registered by now. there is a slight exception in the case of - # post_update requests; this should be fixed. - self.__is_executing = True - head = self._sort_dependencies() - self.__modified = False if self.echo: if head is None: self.logger.info("Task dump: None") @@ -343,8 +323,6 @@ class UOWTransaction(object): self.logger.info("Task dump:\n" + head.dump()) if head is not None: head.execute(self) - #if self.__modified and head is not None: - # raise "Assertion failed ! new pre-execute dependency step should eliminate post-execute changes (except post_update stuff)." self.logger.info("Execute Complete") def post_exec(self): @@ -391,182 +369,25 @@ class UOWTransaction(object): mappers.add(base) return mappers - -class UOWTaskElement(object): - """an element within a UOWTask. corresponds to a single object instance - to be saved, deleted, or just part of the transaction as a placeholder for - further dependencies (i.e. 'listonly'). - in the case of self-referential mappers, may also store a list of childtasks, - further UOWTasks containing objects dependent on this element's object instance.""" - def __init__(self, obj): - self.obj = obj - self.__listonly = True - self.childtasks = [] - self.__isdelete = False - self.__preprocessed = {} - def _get_listonly(self): - return self.__listonly - def _set_listonly(self, value): - """set_listonly is a one-way setter, will only go from True to False.""" - if not value and self.__listonly: - self.__listonly = False - self.clear_preprocessed() - def _get_isdelete(self): - return self.__isdelete - def _set_isdelete(self, value): - if self.__isdelete is not value: - self.__isdelete = value - self.clear_preprocessed() - listonly = property(_get_listonly, _set_listonly) - isdelete = property(_get_isdelete, _set_isdelete) - - def mark_preprocessed(self, processor): - """marks this element as "preprocessed" by a particular UOWDependencyProcessor. preprocessing is the step - which sweeps through all the relationships on all the objects in the flush transaction and adds other objects - which are also affected, In some cases it can switch an object from "tosave" to "todelete". changes to the state - of this UOWTaskElement will reset all "preprocessed" flags, causing it to be preprocessed again. When all UOWTaskElements - have been fully preprocessed by all UOWDependencyProcessors, then the topological sort can be done.""" - self.__preprocessed[processor] = True - def is_preprocessed(self, processor): - return self.__preprocessed.get(processor, False) - def clear_preprocessed(self): - self.__preprocessed.clear() - def __repr__(self): - return "UOWTaskElement/%d: %s/%d %s" % (id(self), self.obj.__class__.__name__, id(self.obj), (self.listonly and 'listonly' or (self.isdelete and 'delete' or 'save')) ) - -class UOWDependencyProcessor(object): - """in between the saving and deleting of objects, process "dependent" data, such as filling in - a foreign key on a child item from a new primary key, or deleting association rows before a - delete. This object acts as a proxy to a DependencyProcessor.""" - def __init__(self, processor, targettask): - self.processor = processor - self.targettask = targettask - def __eq__(self, other): - return other.processor is self.processor and other.targettask is self.targettask - def __hash__(self): - return hash((self.processor, self.targettask)) - - def preexecute(self, trans): - """traverses all objects handled by this dependency processor and locates additional objects which should be - part of the transaction, such as those affected deletes, orphans to be deleted, etc. Returns True if any - objects were preprocessed, or False if no objects were preprocessed.""" - def getobj(elem): - elem.mark_preprocessed(self) - return elem.obj - - ret = False - elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if elem.obj is not None and not elem.is_preprocessed(self)] - if len(elements): - ret = True - self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=False) - - elements = [getobj(elem) for elem in self.targettask.polymorphic_todelete_elements if elem.obj is not None and not elem.is_preprocessed(self)] - if len(elements): - ret = True - self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=True) - return ret - - def execute(self, trans, delete): - if not delete: - self.processor.process_dependencies(self.targettask, [elem.obj for elem in self.targettask.polymorphic_tosave_elements if elem.obj is not None], trans, delete=False) - else: - self.processor.process_dependencies(self.targettask, [elem.obj for elem in self.targettask.polymorphic_todelete_elements if elem.obj is not None], trans, delete=True) - - def get_object_dependencies(self, obj, trans, passive): - return self.processor.get_object_dependencies(obj, trans, passive=passive) - - def whose_dependent_on_who(self, obj, o): - return self.processor.whose_dependent_on_who(obj, o) - - def branch(self, task): - return UOWDependencyProcessor(self.processor, task) - -class UOWExecutor(object): - def execute(self, trans, task, isdelete=None): - if isdelete is not True: - self.execute_save_steps(trans, task) - if isdelete is not False: - self.execute_delete_steps(trans, task) - - def save_objects(self, trans, task): - task._save_objects(trans) - - def delete_objects(self, trans, task): - task._delete_objects(trans) - - def execute_dependency(self, trans, dep, isdelete): - dep.execute(trans, isdelete) - - def execute_save_steps(self, trans, task): - if task.circular is not None: - self.execute_save_steps(trans, task.circular) - else: - self.save_objects(trans, task) - self.execute_cyclical_dependencies(trans, task, False) - self.execute_per_element_childtasks(trans, task, False) - self.execute_dependencies(trans, task, False) - self.execute_dependencies(trans, task, True) - self.execute_childtasks(trans, task, False) - - def execute_delete_steps(self, trans, task): - if task.circular is not None: - self.execute_delete_steps(trans, task.circular) - else: - self.execute_cyclical_dependencies(trans, task, True) - self.execute_childtasks(trans, task, True) - self.execute_per_element_childtasks(trans, task, True) - self.delete_objects(trans, task) - - def execute_dependencies(self, trans, task, isdelete=None): - alltasks = list(task.polymorphic_tasks()) - if isdelete is not True: - for task in alltasks: - for dep in task.dependencies: - self.execute_dependency(trans, dep, False) - if isdelete is not False: - alltasks.reverse() - for task in alltasks: - for dep in task.dependencies: - self.execute_dependency(trans, dep, True) - - def execute_childtasks(self, trans, task, isdelete=None): - for polytask in task.polymorphic_tasks(): - for child in polytask.childtasks: - self.execute(trans, child, isdelete) - - def execute_cyclical_dependencies(self, trans, task, isdelete): - for polytask in task.polymorphic_tasks(): - for dep in polytask.cyclical_dependencies: - self.execute_dependency(trans, dep, isdelete) - - def execute_per_element_childtasks(self, trans, task, isdelete): - for polytask in task.polymorphic_tasks(): - for element in polytask.tosave_elements + polytask.todelete_elements: - self.execute_element_childtasks(trans, element, isdelete) - - def execute_element_childtasks(self, trans, element, isdelete): - for child in element.childtasks: - self.execute(trans, child, isdelete) - class UOWTask(object): """represents the full list of objects that are to be saved/deleted by a specific Mapper.""" def __init__(self, uowtransaction, mapper, circular_parent=None): if not circular_parent: uowtransaction.tasks[mapper] = self - + # the transaction owning this UOWTask self.uowtransaction = uowtransaction - + # the Mapper which this UOWTask corresponds to self.mapper = mapper - + # a dictionary mapping object instances to a corresponding UOWTaskElement. # Each UOWTaskElement represents one instance which is to be saved or # deleted by this UOWTask's Mapper. # in the case of the row-based "circular sort", the UOWTaskElement may # also reference further UOWTasks which are dependent on that UOWTaskElement. self.objects = {} #util.OrderedDict() - + # a list of UOWDependencyProcessors which are executed after saves and # before deletes, to synchronize data to dependent objects self.dependencies = util.Set() @@ -575,7 +396,7 @@ class UOWTask(object): # are to be executed after this UOWTask performs saves and post-save # dependency processing, and before pre-delete processing and deletes self.childtasks = [] - + # whether this UOWTask is circular, meaning it holds a second # UOWTask that contains a special row-based dependency structure. self.circular = None @@ -583,16 +404,16 @@ class UOWTask(object): # for a task thats part of that row-based dependency structure, points # back to the "public facing" task. self.circular_parent = circular_parent - + # a list of UOWDependencyProcessors are derived from the main # set of dependencies, referencing sub-UOWTasks attached to this # one which represent portions of the total list of objects. # this is used for the row-based "circular sort" self.cyclical_dependencies = util.Set() - + def is_empty(self): return len(self.objects) == 0 and len(self.dependencies) == 0 and len(self.childtasks) == 0 - + def append(self, obj, listonly = False, childtask = None, isdelete = False): """appends an object to this task, to be either saved or deleted depending on the 'isdelete' attribute of this UOWTask. 'listonly' indicates that the object should @@ -614,14 +435,14 @@ class UOWTask(object): if isdelete: rec.isdelete = True return retval - + def append_postupdate(self, obj, post_update_cols): # postupdates are UPDATED immeditely (for now) # convert post_update_cols list to a Set so that __hashcode__ is used to compare columns # instead of __eq__ self.mapper.save_obj([obj], self.uowtransaction, postupdate=True, post_update_cols=util.Set(post_update_cols)) return True - + def delete(self, obj): try: del self.objects[obj] @@ -633,11 +454,11 @@ class UOWTask(object): def _delete_objects(self, trans): for task in self.polymorphic_tasks(): task.mapper.delete_obj(task.todelete_objects, trans) - + def execute(self, trans): """executes this UOWTask. saves objects to be saved, processes all dependencies that have been registered, and deletes objects to be deleted. """ - + UOWExecutor().execute(trans, self) def polymorphic_tasks(self): @@ -645,10 +466,10 @@ class UOWTask(object): mappers are inheriting descendants of this UOWTask's mapper. UOWTasks are returned in order of their hierarchy to each other, meaning if UOWTask B's mapper inherits from UOWTask A's mapper, then UOWTask B will appear after UOWTask A in the iteration.""" - + # first us yield self - + # "circular dependency" tasks aren't polymorphic if self.circular_parent is not None: return @@ -662,12 +483,12 @@ class UOWTask(object): else: for t in _tasks_by_mapper(m): yield t - + # main yield loop for task in _tasks_by_mapper(self.mapper): for t in task.polymorphic_tasks(): yield t - + def contains_object(self, obj, polymorphic=False): if polymorphic: for task in self.polymorphic_tasks(): @@ -680,13 +501,13 @@ class UOWTask(object): def is_inserted(self, obj): return not hasattr(obj, '_instance_key') - + def is_deleted(self, obj): try: return self.objects[obj].isdelete except KeyError: return False - + def get_elements(self, polymorphic=False): if polymorphic: for task in self.polymorphic_tasks(): @@ -695,7 +516,7 @@ class UOWTask(object): else: for rec in self.objects.values(): yield rec - + polymorphic_tosave_elements = property(lambda self: [rec for rec in self.get_elements(polymorphic=True) if not rec.isdelete]) polymorphic_todelete_elements = property(lambda self: [rec for rec in self.get_elements(polymorphic=True) if rec.isdelete]) tosave_elements = property(lambda self: [rec for rec in self.get_elements(polymorphic=False) if not rec.isdelete]) @@ -703,30 +524,30 @@ class UOWTask(object): tosave_objects = property(lambda self:[rec.obj for rec in self.get_elements(polymorphic=False) if rec.obj is not None and not rec.listonly and rec.isdelete is False]) todelete_objects = property(lambda self:[rec.obj for rec in self.get_elements(polymorphic=False) if rec.obj is not None and not rec.listonly and rec.isdelete is True]) polymorphic_tosave_objects = property(lambda self:[rec.obj for rec in self.get_elements(polymorphic=True) if rec.obj is not None and not rec.listonly and rec.isdelete is False]) - + def _sort_circular_dependencies(self, trans, cycles): """for a single task, creates a hierarchical tree of "subtasks" which associate specific dependency actions with individual objects. This is used for a "cyclical" task, or a task where elements of its object list contain dependencies on each other. - + this is not the normal case; this logic only kicks in when something like a hierarchical tree is being represented.""" allobjects = [] for task in cycles: allobjects += [e.obj for e in task.get_elements(polymorphic=True)] tuples = [] - + cycles = util.Set(cycles) - + #print "BEGIN CIRC SORT-------" #print "PRE-CIRC:" #print list(cycles) #[0].dump() - + # dependency processors that arent part of the cyclical thing # get put here extradeplist = [] - + # organizes a set of new UOWTasks that will be assembled into # the final tree, for the purposes of holding new UOWDependencyProcessors # which process small sub-sections of dependent parent/child operations @@ -748,7 +569,7 @@ class UOWTask(object): proctask = trans.get_task_by_mapper(dep.processor.mapper.primary_mapper().base_mapper(), True) targettask = trans.get_task_by_mapper(dep.targettask.mapper.base_mapper(), True) return targettask in cycles and (proctask is not None and proctask in cycles) - + # organize all original UOWDependencyProcessors by their target task deps_by_targettask = {} for t in cycles: @@ -761,14 +582,14 @@ class UOWTask(object): l.append(dep) object_to_original_task = {} - + for t in cycles: for task in t.polymorphic_tasks(): for taskelement in task.get_elements(polymorphic=False): obj = taskelement.obj object_to_original_task[obj] = task #print "OBJ", repr(obj), "TASK", repr(task) - + for dep in deps_by_targettask.get(task, []): # is this dependency involved in one of the cycles ? #print "DEP iterate", dep.processor.key, dep.processor.parent, dep.processor.mapper @@ -778,16 +599,16 @@ class UOWTask(object): #print "DEP", dep.processor.key (processor, targettask) = (dep.processor, dep.targettask) isdelete = taskelement.isdelete - + # list of dependent objects from this object childlist = dep.get_object_dependencies(obj, trans, passive=True) if childlist is None: continue # the task corresponding to saving/deleting of those dependent objects childtask = trans.get_task_by_mapper(processor.mapper.primary_mapper()) - + childlist = childlist.added_items() + childlist.unchanged_items() + childlist.deleted_items() - + for o in childlist: if o is None or not childtask.contains_object(o, polymorphic=True): continue @@ -804,7 +625,7 @@ class UOWTask(object): get_dependency_task(whosdep[0], dep).append(whosdep[1], isdelete=isdelete) else: get_dependency_task(obj, dep).append(obj, isdelete=isdelete) - + #print "TUPLES", tuples head = DependencySorter(tuples, allobjects).sort() if head is None: @@ -823,7 +644,7 @@ class UOWTask(object): nexttasks[originating_task] = t parenttask.append(None, listonly=False, isdelete=originating_task.objects[node.item].isdelete, childtask=t) t.append(node.item, originating_task.objects[node.item].listonly, isdelete=originating_task.objects[node.item].isdelete) - + if dependencies.has_key(node.item): for depprocessor, deptask in dependencies[node.item].iteritems(): t.cyclical_dependencies.add(depprocessor.branch(deptask)) @@ -848,8 +669,8 @@ class UOWTask(object): import uowdumper uowdumper.UOWDumper(self, buf) return buf.getvalue() - - + + def __repr__(self): if self.mapper is not None: if self.mapper.__class__.__name__ == 'Mapper': @@ -859,6 +680,164 @@ class UOWTask(object): else: name = '(none)' return ("UOWTask(%d) Mapper: '%s'" % (id(self), name)) + +class UOWTaskElement(object): + """an element within a UOWTask. corresponds to a single object instance + to be saved, deleted, or just part of the transaction as a placeholder for + further dependencies (i.e. 'listonly'). + in the case of self-referential mappers, may also store a list of childtasks, + further UOWTasks containing objects dependent on this element's object instance.""" + def __init__(self, obj): + self.obj = obj + self.__listonly = True + self.childtasks = [] + self.__isdelete = False + self.__preprocessed = {} + def _get_listonly(self): + return self.__listonly + def _set_listonly(self, value): + """set_listonly is a one-way setter, will only go from True to False.""" + if not value and self.__listonly: + self.__listonly = False + self.clear_preprocessed() + def _get_isdelete(self): + return self.__isdelete + def _set_isdelete(self, value): + if self.__isdelete is not value: + self.__isdelete = value + self.clear_preprocessed() + listonly = property(_get_listonly, _set_listonly) + isdelete = property(_get_isdelete, _set_isdelete) + + def mark_preprocessed(self, processor): + """marks this element as "preprocessed" by a particular UOWDependencyProcessor. preprocessing is the step + which sweeps through all the relationships on all the objects in the flush transaction and adds other objects + which are also affected, In some cases it can switch an object from "tosave" to "todelete". changes to the state + of this UOWTaskElement will reset all "preprocessed" flags, causing it to be preprocessed again. When all UOWTaskElements + have been fully preprocessed by all UOWDependencyProcessors, then the topological sort can be done.""" + self.__preprocessed[processor] = True + def is_preprocessed(self, processor): + return self.__preprocessed.get(processor, False) + def clear_preprocessed(self): + self.__preprocessed.clear() + def __repr__(self): + return "UOWTaskElement/%d: %s/%d %s" % (id(self), self.obj.__class__.__name__, id(self.obj), (self.listonly and 'listonly' or (self.isdelete and 'delete' or 'save')) ) + +class UOWDependencyProcessor(object): + """in between the saving and deleting of objects, process "dependent" data, such as filling in + a foreign key on a child item from a new primary key, or deleting association rows before a + delete. This object acts as a proxy to a DependencyProcessor.""" + def __init__(self, processor, targettask): + self.processor = processor + self.targettask = targettask + def __eq__(self, other): + return other.processor is self.processor and other.targettask is self.targettask + def __hash__(self): + return hash((self.processor, self.targettask)) + + def preexecute(self, trans): + """traverses all objects handled by this dependency processor and locates additional objects which should be + part of the transaction, such as those affected deletes, orphans to be deleted, etc. Returns True if any + objects were preprocessed, or False if no objects were preprocessed.""" + def getobj(elem): + elem.mark_preprocessed(self) + return elem.obj + + ret = False + elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if elem.obj is not None and not elem.is_preprocessed(self)] + if len(elements): + ret = True + self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=False) + + elements = [getobj(elem) for elem in self.targettask.polymorphic_todelete_elements if elem.obj is not None and not elem.is_preprocessed(self)] + if len(elements): + ret = True + self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=True) + return ret + + def execute(self, trans, delete): + if not delete: + self.processor.process_dependencies(self.targettask, [elem.obj for elem in self.targettask.polymorphic_tosave_elements if elem.obj is not None], trans, delete=False) + else: + self.processor.process_dependencies(self.targettask, [elem.obj for elem in self.targettask.polymorphic_todelete_elements if elem.obj is not None], trans, delete=True) + + def get_object_dependencies(self, obj, trans, passive): + return self.processor.get_object_dependencies(obj, trans, passive=passive) + + def whose_dependent_on_who(self, obj, o): + return self.processor.whose_dependent_on_who(obj, o) + + def branch(self, task): + return UOWDependencyProcessor(self.processor, task) + +class UOWExecutor(object): + """encapsulates the execution traversal of a UOWTransaction structure.""" + def execute(self, trans, task, isdelete=None): + if isdelete is not True: + self.execute_save_steps(trans, task) + if isdelete is not False: + self.execute_delete_steps(trans, task) + + def save_objects(self, trans, task): + task._save_objects(trans) + + def delete_objects(self, trans, task): + task._delete_objects(trans) + + def execute_dependency(self, trans, dep, isdelete): + dep.execute(trans, isdelete) + + def execute_save_steps(self, trans, task): + if task.circular is not None: + self.execute_save_steps(trans, task.circular) + else: + self.save_objects(trans, task) + self.execute_cyclical_dependencies(trans, task, False) + self.execute_per_element_childtasks(trans, task, False) + self.execute_dependencies(trans, task, False) + self.execute_dependencies(trans, task, True) + self.execute_childtasks(trans, task, False) + + def execute_delete_steps(self, trans, task): + if task.circular is not None: + self.execute_delete_steps(trans, task.circular) + else: + self.execute_cyclical_dependencies(trans, task, True) + self.execute_childtasks(trans, task, True) + self.execute_per_element_childtasks(trans, task, True) + self.delete_objects(trans, task) + + def execute_dependencies(self, trans, task, isdelete=None): + alltasks = list(task.polymorphic_tasks()) + if isdelete is not True: + for task in alltasks: + for dep in task.dependencies: + self.execute_dependency(trans, dep, False) + if isdelete is not False: + alltasks.reverse() + for task in alltasks: + for dep in task.dependencies: + self.execute_dependency(trans, dep, True) + + def execute_childtasks(self, trans, task, isdelete=None): + for polytask in task.polymorphic_tasks(): + for child in polytask.childtasks: + self.execute(trans, child, isdelete) + + def execute_cyclical_dependencies(self, trans, task, isdelete): + for polytask in task.polymorphic_tasks(): + for dep in polytask.cyclical_dependencies: + self.execute_dependency(trans, dep, isdelete) + + def execute_per_element_childtasks(self, trans, task, isdelete): + for polytask in task.polymorphic_tasks(): + for element in polytask.tosave_elements + polytask.todelete_elements: + self.execute_element_childtasks(trans, element, isdelete) + + def execute_element_childtasks(self, trans, element, isdelete): + for child in element.childtasks: + self.execute(trans, child, isdelete) + class DependencySorter(topological.QueueDependencySorter): pass diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql_util.py index 4935b1adda..bfbcff5541 100644 --- a/lib/sqlalchemy/sql_util.py +++ b/lib/sqlalchemy/sql_util.py @@ -1,6 +1,4 @@ -import sqlalchemy.sql as sql -import sqlalchemy.schema as schema -import sqlalchemy.util as util +from sqlalchemy import sql, util, schema, topological """utility functions that build upon SQL and Schema constructs""" @@ -36,7 +34,6 @@ class TableCollection(object): return sorted def _do_sort(self): - import sqlalchemy.orm.topological tuples = [] class TVisitor(schema.SchemaVisitor): def visit_foreign_key(_self, fkey): @@ -49,7 +46,7 @@ class TableCollection(object): vis = TVisitor() for table in self.tables: table.accept_schema_visitor(vis) - sorter = sqlalchemy.orm.topological.QueueDependencySorter( tuples, self.tables ) + sorter = topological.QueueDependencySorter( tuples, self.tables ) head = sorter.sort() sequence = [] def to_sequence( node, seq=sequence): diff --git a/lib/sqlalchemy/orm/topological.py b/lib/sqlalchemy/topological.py similarity index 98% rename from lib/sqlalchemy/orm/topological.py rename to lib/sqlalchemy/topological.py index 8c481b6f11..948b3cfead 100644 --- a/lib/sqlalchemy/orm/topological.py +++ b/lib/sqlalchemy/topological.py @@ -36,7 +36,7 @@ import sqlalchemy.util as util from sqlalchemy.exceptions import * class QueueDependencySorter(object): - """this is a topological sort from wikipedia. its very stable. it creates a straight-line + """topological sort adapted from wikipedia's article on the subject. it creates a straight-line list of elements, then a second pass groups non-dependent actions together to build more of a tree structure with siblings.""" class Node: diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 6a26c8b4be..a048adbd87 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -477,6 +477,7 @@ class MapperTest(MapperSuperTest): print u[0].orders[1].items[0].keywords[1] self.assert_sql_count(db, go, 3) sess.clear() + print "MARK" u = q2.select() self.assert_sql_count(db, go, 2) @@ -873,6 +874,19 @@ class EagerTest(MapperSuperTest): {'user_id' : 9, 'addresses' : (Address, [])} ) + def testcustom(self): + mapper(User, users, properties={ + 'addresses':relation(Address, lazy=False) + }) + mapper(Address, addresses) + + selectquery = users.outerjoin(addresses).select(use_labels=True) + q = create_session().query(User) + + l = q.options(contains_eager('addresses')).instances(selectquery.execute()) +# l = q.instances(selectquery.execute()) + self.assert_result(l, User, *user_address_result) + def testorderby_desc(self): m = mapper(Address, addresses) -- 2.47.2