From: Mike Bayer Date: Sun, 28 Sep 2008 19:10:22 +0000 (+0000) Subject: - session.execute() will execute a Sequence object passed to X-Git-Tag: rel_0_5rc2~24 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6122c3ad4802040a885f61ae97ecac03605057b4;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - session.execute() will execute a Sequence object passed to it (regression from 0.4). - Removed the "raiseerror" keyword argument from object_mapper() and class_mapper(). These functions raise in all cases if the given class/instance is not mapped. - Refined ExtensionCarrier to be itself a dict, removed 'methods' accessor - moved identity_key tests to test/orm/utils.py - some docstrings --- diff --git a/CHANGES b/CHANGES index 3686034bcd..3f981b927d 100644 --- a/CHANGES +++ b/CHANGES @@ -31,6 +31,15 @@ CHANGES for zero length slices, slices with None on either end. [ticket:1177] + - Added an example illustrating Celko's "nested sets" as a + SQLA mapping. + + - session.execute() will execute a Sequence object passed to + it (regression from 0.4). + + - Removed the "raiseerror" keyword argument from object_mapper() + and class_mapper(). These functions raise in all cases + if the given class/instance is not mapped. - sql - column.in_(someselect) can now be used as a columns-clause expression without the subquery bleeding into the FROM clause diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 886fcec91e..dd30bdda7a 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -630,9 +630,9 @@ class Mapper(object): class_mapper(cls) if cls.__dict__.get(clskey) is self: - # FIXME: there should not be any scenarios where - # a mapper compile leaves this CompileOnAttr in - # place. + # if this warning occurs, it usually means mapper + # compilation has failed, but operations upon the mapped + # classes have proceeded. util.warn( ("Attribute '%s' on class '%s' was not replaced during " "mapper compilation operation") % (clskey, cls.__name__)) @@ -755,11 +755,18 @@ class Mapper(object): for c in columns: mc = self.mapped_table.corresponding_column(c) if not mc: - raise sa_exc.ArgumentError("Column '%s' is not represented in mapper's table. Use the `column_property()` function to force this column to be mapped as a read-only attribute." % str(c)) + raise sa_exc.ArgumentError("Column '%s' is not represented in mapper's table. " + "Use the `column_property()` function to force this column " + "to be mapped as a read-only attribute." % c) mapped_column.append(mc) prop = ColumnProperty(*mapped_column) else: - raise sa_exc.ArgumentError("WARNING: column '%s' conflicts with property '%s'. To resolve this, map the column to the class under a different name in the 'properties' dictionary. Or, to remove all awareness of the column entirely (including its availability as a foreign key), use the 'include_properties' or 'exclude_properties' mapper arguments to control specifically which table columns get mapped." % (column.key, repr(prop))) + raise sa_exc.ArgumentError("WARNING: column '%s' conflicts with property '%r'. " + "To resolve this, map the column to the class under a different " + "name in the 'properties' dictionary. Or, to remove all awareness " + "of the column entirely (including its availability as a foreign key), " + "use the 'include_properties' or 'exclude_properties' mapper arguments " + "to control specifically which table columns get mapped." % (column.key, prop)) if isinstance(prop, ColumnProperty): col = self.mapped_table.corresponding_column(prop.columns[0]) @@ -879,7 +886,7 @@ class Mapper(object): for name in method.__sa_validators__: self._validators[name] = method - if 'reconstruct_instance' in self.extension.methods: + if 'reconstruct_instance' in self.extension: def reconstruct(instance): self.extension.reconstruct_instance(self, instance) event_registry.add_listener('on_load', reconstruct) @@ -1067,10 +1074,10 @@ class Mapper(object): # call before_XXX extensions for state, mapper, connection, has_identity in tups: if not has_identity: - if 'before_insert' in mapper.extension.methods: + if 'before_insert' in mapper.extension: mapper.extension.before_insert(mapper, connection, state.obj()) else: - if 'before_update' in mapper.extension.methods: + if 'before_update' in mapper.extension: mapper.extension.before_update(mapper, connection, state.obj()) for state, mapper, connection, has_identity in tups: @@ -1237,10 +1244,10 @@ class Mapper(object): # call after_XXX extensions if not has_identity: - if 'after_insert' in mapper.extension.methods: + if 'after_insert' in mapper.extension: mapper.extension.after_insert(mapper, connection, state.obj()) else: - if 'after_update' in mapper.extension.methods: + if 'after_update' in mapper.extension: mapper.extension.after_update(mapper, connection, state.obj()) def _postfetch(self, uowtransaction, connection, table, state, resultproxy, params, value_params): @@ -1289,7 +1296,7 @@ class Mapper(object): tups = [(state, _state_mapper(state), connection) for state in _sort_states(states)] for state, mapper, connection in tups: - if 'before_delete' in mapper.extension.methods: + if 'before_delete' in mapper.extension: mapper.extension.before_delete(mapper, connection, state.obj()) table_to_mapper = {} @@ -1326,7 +1333,7 @@ class Mapper(object): raise exc.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (c.rowcount, len(del_objects))) for state, mapper, connection in tups: - if 'after_delete' in mapper.extension.methods: + if 'after_delete' in mapper.extension: mapper.extension.after_delete(mapper, connection, state.obj()) def _register_dependencies(self, uowcommit): @@ -1419,15 +1426,15 @@ class Mapper(object): if not extension: extension = self.extension - translate_row = 'translate_row' in extension.methods - create_instance = 'create_instance' in extension.methods - populate_instance = 'populate_instance' in extension.methods - append_result = 'append_result' in extension.methods + translate_row = extension.get('translate_row', None) + create_instance = extension.get('create_instance', None) + populate_instance = extension.get('populate_instance', None) + append_result = extension.get('append_result', None) populate_existing = context.populate_existing or self.always_refresh def _instance(row, result): if translate_row: - ret = extension.translate_row(self, context, row) + ret = translate_row(self, context, row) if ret is not EXT_CONTINUE: row = ret @@ -1489,7 +1496,7 @@ class Mapper(object): loaded_instance = True if create_instance: - instance = extension.create_instance(self, context, row, self.class_) + instance = create_instance(self, context, row, self.class_) if instance is EXT_CONTINUE: instance = self.class_manager.new_instance() else: @@ -1517,7 +1524,7 @@ class Mapper(object): state.runid = context.runid context.progress.add(state) - if not populate_instance or extension.populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: + if not populate_instance or populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: populate_state(state, row, isnew, only_load_props) else: @@ -1533,13 +1540,13 @@ class Mapper(object): attrs = state.unloaded context.partials[state] = attrs #<-- allow query.instances to commit the subset of attrs - if not populate_instance or extension.populate_instance(self, context, row, instance, only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: + if not populate_instance or populate_instance(self, context, row, instance, only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: populate_state(state, row, isnew, attrs, instancekey=identitykey) if loaded_instance: state._run_on_load(instance) - if result is not None and (not append_result or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE): + if result is not None and (not append_result or append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE): result.append(instance) return instance @@ -1651,7 +1658,7 @@ def _event_on_init(state, instance, args, kwargs): instrumenting_mapper = state.manager.info[_INSTRUMENTOR] # compile() always compiles all mappers instrumenting_mapper.compile() - if 'init_instance' in instrumenting_mapper.extension.methods: + if 'init_instance' in instrumenting_mapper.extension: instrumenting_mapper.extension.init_instance( instrumenting_mapper, instrumenting_mapper.class_, state.manager.events.original_init, @@ -1661,7 +1668,7 @@ def _event_on_init_failure(state, instance, args, kwargs): """Run init_failed hooks.""" instrumenting_mapper = state.manager.info[_INSTRUMENTOR] - if 'init_failed' in instrumenting_mapper.extension.methods: + if 'init_failed' in instrumenting_mapper.extension: util.warn_exception( instrumenting_mapper.extension.init_failed, instrumenting_mapper, instrumenting_mapper.class_, diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index e25716316a..4b6aa0018b 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1109,7 +1109,7 @@ class Query(object): else: filter = None - custom_rows = single_entity and 'append_result' in self._entities[0].extension.methods + custom_rows = single_entity and 'append_result' in self._entities[0].extension (process, labels) = zip(*[query_entity.row_processor(self, context, custom_rows) for query_entity in self._entities]) @@ -1433,7 +1433,6 @@ class Query(object): return result.rowcount - def _compile_context(self, labels=True): context = QueryContext(self) @@ -1563,9 +1562,6 @@ class Query(object): crit = self._adapt_clause(crit, False, False) context.whereclause = sql.and_(context.whereclause, crit) - def __log_debug(self, msg): - self.logger.debug(msg) - def __str__(self): return str(self._compile_context().statement) diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 6ea28148d9..6e33be96bf 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -7,8 +7,9 @@ import sqlalchemy.exceptions as sa_exc from sqlalchemy.util import ScopedRegistry, to_list, get_cls_kwargs from sqlalchemy.orm import ( - EXT_CONTINUE, MapperExtension, class_mapper, object_session, + EXT_CONTINUE, MapperExtension, class_mapper, object_session ) +from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm.session import Session @@ -102,15 +103,16 @@ class ScopedSession(object): """ class query(object): def __get__(s, instance, owner): - mapper = class_mapper(owner, raiseerror=False) - if mapper: - if query_cls: - # custom query class - return query_cls(mapper, session=self.registry()) - else: - # session's configured query class - return self.registry().query(mapper) - else: + try: + mapper = class_mapper(owner) + if mapper: + if query_cls: + # custom query class + return query_cls(mapper, session=self.registry()) + else: + # session's configured query class + return self.registry().query(mapper) + except orm_exc.UnmappedClassError: return None return query() diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index ad987430a5..0ca59141e0 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1535,18 +1535,7 @@ class Session(object): return util.IdentitySet(self._new.values()) -def _expire_state(state, attribute_names): - """Stand-alone expire instance function. - - Installs a callable with the given instance's _state which will fire off - when any of the named attributes are accessed; their existing value is - removed. - - If the list is None or blank, the entire instance is expired. - - """ - state.expire_attributes(attribute_names) - +_expire_state = attributes.InstanceState.expire_attributes register_attribute = unitofwork.register_attribute _sessions = weakref.WeakValueDictionary() @@ -1591,6 +1580,7 @@ def _state_for_unknown_persistence_instance(instance): def object_session(instance): """Return the ``Session`` to which instance belongs, or None.""" + return _state_session(attributes.instance_state(instance)) def _state_session(state): diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index c1d93153e8..308266cd83 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -20,7 +20,8 @@ from sqlalchemy.orm import util as mapperutil class DefaultColumnLoader(LoaderStrategy): - def _register_attribute(self, compare_function, copy_function, mutable_scalars, comparator_factory, callable_=None, proxy_property=None, active_history=False): + def _register_attribute(self, compare_function, copy_function, mutable_scalars, + comparator_factory, callable_=None, proxy_property=None, active_history=False): self.logger.info("%s register managed attribute" % self) attribute_ext = util.to_list(self.parent_property.extension) or [] diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 453b9f510d..f461c4bdcb 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -170,7 +170,7 @@ def identity_key(*args, **kwargs): mapper = object_mapper(instance) return mapper.identity_key_from_instance(instance) -class ExtensionCarrier(object): +class ExtensionCarrier(dict): """Fronts an ordered collection of MapperExtension objects. Bundles multiple MapperExtensions into a unified callable unit, @@ -179,9 +179,9 @@ class ExtensionCarrier(object): carrier.after_insert(...args...) - Also includes a 'methods' dictionary accessor which allows for a quick - check if a particular method is overridden on any contained - MapperExtensions. + The dictionary interface provides containment for implemented + method names mapped to a callable which executes that method + for participating extensions. """ @@ -189,7 +189,6 @@ class ExtensionCarrier(object): if not method.startswith('_')) def __init__(self, extensions=None): - self.methods = {} self._extensions = [] for ext in extensions or (): self.append(ext) @@ -213,12 +212,11 @@ class ExtensionCarrier(object): def _register(self, extension): """Register callable fronts for overridden interface methods.""" - for method in self.interface: - if method in self.methods: - continue + + for method in self.interface.difference(self): impl = getattr(extension, method, None) if impl and impl is not getattr(MapperExtension, method): - self.methods[method] = self._create_do(method) + self[method] = self._create_do(method) def _create_do(self, method): """Return a closure that loops over impls of the named method.""" @@ -230,10 +228,7 @@ class ExtensionCarrier(object): return ret else: return EXT_CONTINUE - try: - _do.__name__ = method.im_func.func_name - except: - pass + _do.__name__ = method return _do @staticmethod @@ -242,9 +237,10 @@ class ExtensionCarrier(object): def __getattr__(self, key): """Delegate MapperExtension methods to bundled fronts.""" + if key not in self.interface: raise AttributeError(key) - return self.methods.get(key, self._pass) + return self.get(key, self._pass) class ORMAdapter(sql_util.ColumnAdapter): def __init__(self, entity, equivalents=None, chain_to=None): @@ -320,6 +316,11 @@ class AliasedComparator(PropComparator): return self.adapter.traverse(self.comparator.reverse_operate(op, *other, **kwargs)) def _orm_annotate(element, exclude=None): + """Deep copy the given ClauseElement, annotating each element with the "_orm_adapt" flag. + + Elements within the exclude collection will be cloned but not annotated. + + """ def clone(elem): if exclude and elem in exclude: elem = elem._clone() @@ -334,7 +335,8 @@ def _orm_annotate(element, exclude=None): class _ORMJoin(expression.Join): - + """Extend Join to support ORM constructs as input.""" + __visit_name__ = expression.Join.__visit_name__ def __init__(self, left, right, onclause=None, isouter=False): @@ -387,9 +389,27 @@ class _ORMJoin(expression.Join): return _ORMJoin(self, right, onclause, True) def join(left, right, onclause=None, isouter=False): + """Produce an inner join between left and right clauses. + + In addition to the interface provided by + sqlalchemy.sql.join(), left and right may be mapped + classes or AliasedClass instances. The onclause may be a + string name of a relation(), or a class-bound descriptor + representing a relation. + + """ return _ORMJoin(left, right, onclause, isouter) def outerjoin(left, right, onclause=None): + """Produce a left outer join between left and right clauses. + + In addition to the interface provided by + sqlalchemy.sql.outerjoin(), left and right may be mapped + classes or AliasedClass instances. The onclause may be a + string name of a relation(), or a class-bound descriptor + representing a relation. + + """ return _ORMJoin(left, right, onclause, True) def with_parent(instance, prop): @@ -417,6 +437,16 @@ def with_parent(instance, prop): def _entity_info(entity, compile=True): + """Return mapping information given a class, mapper, or AliasedClass. + + Returns 3-tuple of: mapper, mapped selectable, boolean indicating if this + is an aliased() construct. + + If the given entity is not a mapper, mapped class, or aliased construct, + returns None, the entity, False. This is typically used to allow + unmapped selectables through. + + """ if isinstance(entity, AliasedClass): return entity._AliasedClass__mapper, entity._AliasedClass__alias, True elif _is_mapped_class(entity): @@ -432,6 +462,11 @@ def _entity_info(entity, compile=True): return None, entity, False def _entity_descriptor(entity, key): + """Return attribute/property information given an entity and string name. + + Returns a 2-tuple representing InstrumentedAttribute/MapperProperty. + + """ if isinstance(entity, AliasedClass): desc = getattr(entity, key) return desc, desc.property @@ -459,33 +494,26 @@ def _is_aliased_class(entity): def _state_mapper(state): return state.manager.mapper -def object_mapper(object, raiseerror=True): +def object_mapper(instance): """Given an object, return the primary Mapper associated with the object instance. - - object - The object instance. - - raiseerror - Defaults to True: raise an ``InvalidRequestError`` if no mapper can - be located. If False, return None. - + + Raises UnmappedInstanceError if no mapping is configured. + """ try: - state = attributes.instance_state(object) + state = attributes.instance_state(instance) + if not state.manager.mapper: + raise exc.UnmappedInstanceError(instance) + return state.manager.mapper except exc.NO_STATE: - if not raiseerror: - return None - raise exc.UnmappedInstanceError(object) - return class_mapper( - type(object), compile=False, raiseerror=raiseerror) + raise exc.UnmappedInstanceError(instance) -def class_mapper(class_, compile=True, raiseerror=True): +def class_mapper(class_, compile=True): """Given a class (or an object), return the primary Mapper associated with the key. - If no mapper can be located, raises ``InvalidRequestError``. - - """ + Raises UnmappedClassError if no mapping is configured. + """ if not isinstance(class_, type): class_ = type(class_) try: @@ -497,9 +525,8 @@ def class_mapper(class_, compile=True, raiseerror=True): raise AttributeError except exc.NO_STATE: - if not raiseerror: - return raise exc.UnmappedClassError(class_) + if compile: mapper = mapper.compile() return mapper @@ -538,11 +565,12 @@ def instance_str(instance): return state_str(attributes.instance_state(instance)) def state_str(state): - """Return a string describing an instance.""" + """Return a string describing an instance via its InstanceState.""" + if state is None: return "None" else: - return state.class_.__name__ + "@" + hex(id(state.obj())) + return '<%s at 0x%x>' % (state.class_.__name__, id(state.obj())) def attribute_str(instance, attribute): return instance_str(instance) + "." + attribute diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 81c9d39363..8f611c05e3 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -880,7 +880,10 @@ def _compound_select(keyword, *selects, **kwargs): return CompoundSelect(keyword, *selects, **kwargs) def _is_literal(element): - return not isinstance(element, (ClauseElement, Operators)) + global schema + if not schema: + from sqlalchemy import schema + return not isinstance(element, (ClauseElement, Operators, schema.SchemaItem)) def _from_objects(*elements, **kwargs): return itertools.chain(*[element._get_from_objects(**kwargs) for element in elements]) diff --git a/test/orm/session.py b/test/orm/session.py index 8dd98512e3..f6cd2cc91a 100644 --- a/test/orm/session.py +++ b/test/orm/session.py @@ -4,7 +4,7 @@ import inspect import pickle from sqlalchemy.orm import create_session, sessionmaker, attributes from testlib import engines, sa, testing, config -from testlib.sa import Table, Column, Integer, String +from testlib.sa import Table, Column, Integer, String, Sequence from testlib.sa.orm import mapper, relation, backref from testlib.testing import eq_ from engine import _base as engine_base @@ -62,6 +62,17 @@ class SessionTest(_fixtures.FixtureTest): finally: c.close() + @testing.requires.sequences + def test_sequence_execute(self): + seq = Sequence("some_sequence") + seq.create(testing.db) + try: + sess = create_session(bind=testing.db) + eq_(sess.execute(seq), 1) + finally: + seq.drop(testing.db) + + @testing.resolve_artifact_names def test_expunge_cascade(self): mapper(Address, addresses) @@ -842,33 +853,6 @@ class SessionTest(_fixtures.FixtureTest): assert s.query(Address).one().id == a.id assert s.query(User).first() is None - @testing.resolve_artifact_names - def test_identity_key_1(self): - mapper(User, users) - s = create_session() - key = s.identity_key(User, 1) - eq_(key, (User, (1,))) - key = s.identity_key(User, ident=1) - eq_(key, (User, (1,))) - - @testing.resolve_artifact_names - def test_identity_key_2(self): - mapper(User, users) - s = create_session() - u = User(name='u1') - s.add(u) - s.flush() - key = s.identity_key(instance=u) - eq_(key, (User, (u.id,))) - - @testing.resolve_artifact_names - def test_identity_key_3(self): - mapper(User, users) - s = create_session() - row = {users.c.id: 1, users.c.name: "Frank"} - key = s.identity_key(User, row=row) - eq_(key, (User, (1,))) - @testing.resolve_artifact_names def test_extension(self): mapper(User, users) diff --git a/test/orm/utils.py b/test/orm/utils.py index 1f2cbe13a7..0a449fbf70 100644 --- a/test/orm/utils.py +++ b/test/orm/utils.py @@ -7,16 +7,18 @@ from sqlalchemy import Integer from sqlalchemy import MetaData from sqlalchemy import Table from sqlalchemy.orm import aliased -from sqlalchemy.orm import mapper +from sqlalchemy.orm import mapper, create_session +from orm import _fixtures +from testlib.testing import eq_ class ExtensionCarrierTest(TestBase): def test_basic(self): carrier = util.ExtensionCarrier() - assert 'translate_row' not in carrier.methods + assert 'translate_row' not in carrier assert carrier.translate_row() is interfaces.EXT_CONTINUE - assert 'translate_row' not in carrier.methods + assert 'translate_row' not in carrier self.assertRaises(AttributeError, lambda: carrier.snickysnack) @@ -27,15 +29,15 @@ class ExtensionCarrierTest(TestBase): return self.marker carrier.append(Partial('end')) - assert 'translate_row' in carrier.methods + assert 'translate_row' in carrier assert carrier.translate_row(None) == 'end' carrier.push(Partial('front')) assert carrier.translate_row(None) == 'front' - assert 'populate_instance' not in carrier.methods + assert 'populate_instance' not in carrier carrier.append(interfaces.MapperExtension) - assert 'populate_instance' in carrier.methods + assert 'populate_instance' in carrier assert carrier.interface for m in carrier.interface: @@ -201,8 +203,37 @@ class AliasedClassTest(TestBase): assert_table(Point.left_of(p2), table) assert_table(alias.left_of(p2), alias_table) - +class IdentityKeyTest(_fixtures.FixtureTest): + run_inserts = None + + @testing.resolve_artifact_names + def test_identity_key_1(self): + mapper(User, users) + + key = util.identity_key(User, 1) + eq_(key, (User, (1,))) + key = util.identity_key(User, ident=1) + eq_(key, (User, (1,))) + + @testing.resolve_artifact_names + def test_identity_key_2(self): + mapper(User, users) + s = create_session() + u = User(name='u1') + s.add(u) + s.flush() + key = util.identity_key(instance=u) + eq_(key, (User, (u.id,))) + + @testing.resolve_artifact_names + def test_identity_key_3(self): + mapper(User, users) + + row = {users.c.id: 1, users.c.name: "Frank"} + key = util.identity_key(User, row=row) + eq_(key, (User, (1,))) + if __name__ == '__main__': testenv.main()