it (regression from 0.4).
- Removed the "raiseerror" keyword argument from object_mapper()
and class_mapper(). These functions raise in all cases
if the given class/instance is not mapped.
- Refined ExtensionCarrier to be itself a dict, removed
'methods' accessor
- moved identity_key tests to test/orm/utils.py
- some docstrings
for zero length slices, slices with None on either end.
[ticket:1177]
+ - Added an example illustrating Celko's "nested sets" as a
+ SQLA mapping.
+
+ - session.execute() will execute a Sequence object passed to
+ it (regression from 0.4).
+
+ - Removed the "raiseerror" keyword argument from object_mapper()
+ and class_mapper(). These functions raise in all cases
+ if the given class/instance is not mapped.
- sql
- column.in_(someselect) can now be used as a columns-clause
expression without the subquery bleeding into the FROM clause
class_mapper(cls)
if cls.__dict__.get(clskey) is self:
- # FIXME: there should not be any scenarios where
- # a mapper compile leaves this CompileOnAttr in
- # place.
+ # if this warning occurs, it usually means mapper
+ # compilation has failed, but operations upon the mapped
+ # classes have proceeded.
util.warn(
("Attribute '%s' on class '%s' was not replaced during "
"mapper compilation operation") % (clskey, cls.__name__))
for c in columns:
mc = self.mapped_table.corresponding_column(c)
if not mc:
- raise sa_exc.ArgumentError("Column '%s' is not represented in mapper's table. Use the `column_property()` function to force this column to be mapped as a read-only attribute." % str(c))
+ raise sa_exc.ArgumentError("Column '%s' is not represented in mapper's table. "
+ "Use the `column_property()` function to force this column "
+ "to be mapped as a read-only attribute." % c)
mapped_column.append(mc)
prop = ColumnProperty(*mapped_column)
else:
- raise sa_exc.ArgumentError("WARNING: column '%s' conflicts with property '%s'. To resolve this, map the column to the class under a different name in the 'properties' dictionary. Or, to remove all awareness of the column entirely (including its availability as a foreign key), use the 'include_properties' or 'exclude_properties' mapper arguments to control specifically which table columns get mapped." % (column.key, repr(prop)))
+ raise sa_exc.ArgumentError("WARNING: column '%s' conflicts with property '%r'. "
+ "To resolve this, map the column to the class under a different "
+ "name in the 'properties' dictionary. Or, to remove all awareness "
+ "of the column entirely (including its availability as a foreign key), "
+ "use the 'include_properties' or 'exclude_properties' mapper arguments "
+ "to control specifically which table columns get mapped." % (column.key, prop))
if isinstance(prop, ColumnProperty):
col = self.mapped_table.corresponding_column(prop.columns[0])
for name in method.__sa_validators__:
self._validators[name] = method
- if 'reconstruct_instance' in self.extension.methods:
+ if 'reconstruct_instance' in self.extension:
def reconstruct(instance):
self.extension.reconstruct_instance(self, instance)
event_registry.add_listener('on_load', reconstruct)
# call before_XXX extensions
for state, mapper, connection, has_identity in tups:
if not has_identity:
- if 'before_insert' in mapper.extension.methods:
+ if 'before_insert' in mapper.extension:
mapper.extension.before_insert(mapper, connection, state.obj())
else:
- if 'before_update' in mapper.extension.methods:
+ if 'before_update' in mapper.extension:
mapper.extension.before_update(mapper, connection, state.obj())
for state, mapper, connection, has_identity in tups:
# call after_XXX extensions
if not has_identity:
- if 'after_insert' in mapper.extension.methods:
+ if 'after_insert' in mapper.extension:
mapper.extension.after_insert(mapper, connection, state.obj())
else:
- if 'after_update' in mapper.extension.methods:
+ if 'after_update' in mapper.extension:
mapper.extension.after_update(mapper, connection, state.obj())
def _postfetch(self, uowtransaction, connection, table, state, resultproxy, params, value_params):
tups = [(state, _state_mapper(state), connection) for state in _sort_states(states)]
for state, mapper, connection in tups:
- if 'before_delete' in mapper.extension.methods:
+ if 'before_delete' in mapper.extension:
mapper.extension.before_delete(mapper, connection, state.obj())
table_to_mapper = {}
raise exc.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (c.rowcount, len(del_objects)))
for state, mapper, connection in tups:
- if 'after_delete' in mapper.extension.methods:
+ if 'after_delete' in mapper.extension:
mapper.extension.after_delete(mapper, connection, state.obj())
def _register_dependencies(self, uowcommit):
if not extension:
extension = self.extension
- translate_row = 'translate_row' in extension.methods
- create_instance = 'create_instance' in extension.methods
- populate_instance = 'populate_instance' in extension.methods
- append_result = 'append_result' in extension.methods
+ translate_row = extension.get('translate_row', None)
+ create_instance = extension.get('create_instance', None)
+ populate_instance = extension.get('populate_instance', None)
+ append_result = extension.get('append_result', None)
populate_existing = context.populate_existing or self.always_refresh
def _instance(row, result):
if translate_row:
- ret = extension.translate_row(self, context, row)
+ ret = translate_row(self, context, row)
if ret is not EXT_CONTINUE:
row = ret
loaded_instance = True
if create_instance:
- instance = extension.create_instance(self, context, row, self.class_)
+ instance = create_instance(self, context, row, self.class_)
if instance is EXT_CONTINUE:
instance = self.class_manager.new_instance()
else:
state.runid = context.runid
context.progress.add(state)
- if not populate_instance or extension.populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
+ if not populate_instance or populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
populate_state(state, row, isnew, only_load_props)
else:
attrs = state.unloaded
context.partials[state] = attrs #<-- allow query.instances to commit the subset of attrs
- if not populate_instance or extension.populate_instance(self, context, row, instance, only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
+ if not populate_instance or populate_instance(self, context, row, instance, only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
populate_state(state, row, isnew, attrs, instancekey=identitykey)
if loaded_instance:
state._run_on_load(instance)
- if result is not None and (not append_result or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE):
+ if result is not None and (not append_result or append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE):
result.append(instance)
return instance
instrumenting_mapper = state.manager.info[_INSTRUMENTOR]
# compile() always compiles all mappers
instrumenting_mapper.compile()
- if 'init_instance' in instrumenting_mapper.extension.methods:
+ if 'init_instance' in instrumenting_mapper.extension:
instrumenting_mapper.extension.init_instance(
instrumenting_mapper, instrumenting_mapper.class_,
state.manager.events.original_init,
"""Run init_failed hooks."""
instrumenting_mapper = state.manager.info[_INSTRUMENTOR]
- if 'init_failed' in instrumenting_mapper.extension.methods:
+ if 'init_failed' in instrumenting_mapper.extension:
util.warn_exception(
instrumenting_mapper.extension.init_failed,
instrumenting_mapper, instrumenting_mapper.class_,
else:
filter = None
- custom_rows = single_entity and 'append_result' in self._entities[0].extension.methods
+ custom_rows = single_entity and 'append_result' in self._entities[0].extension
(process, labels) = zip(*[query_entity.row_processor(self, context, custom_rows) for query_entity in self._entities])
return result.rowcount
-
def _compile_context(self, labels=True):
context = QueryContext(self)
crit = self._adapt_clause(crit, False, False)
context.whereclause = sql.and_(context.whereclause, crit)
- def __log_debug(self, msg):
- self.logger.debug(msg)
-
def __str__(self):
return str(self._compile_context().statement)
import sqlalchemy.exceptions as sa_exc
from sqlalchemy.util import ScopedRegistry, to_list, get_cls_kwargs
from sqlalchemy.orm import (
- EXT_CONTINUE, MapperExtension, class_mapper, object_session,
+ EXT_CONTINUE, MapperExtension, class_mapper, object_session
)
+from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm.session import Session
"""
class query(object):
def __get__(s, instance, owner):
- mapper = class_mapper(owner, raiseerror=False)
- if mapper:
- if query_cls:
- # custom query class
- return query_cls(mapper, session=self.registry())
- else:
- # session's configured query class
- return self.registry().query(mapper)
- else:
+ try:
+ mapper = class_mapper(owner)
+ if mapper:
+ if query_cls:
+ # custom query class
+ return query_cls(mapper, session=self.registry())
+ else:
+ # session's configured query class
+ return self.registry().query(mapper)
+ except orm_exc.UnmappedClassError:
return None
return query()
return util.IdentitySet(self._new.values())
-def _expire_state(state, attribute_names):
- """Stand-alone expire instance function.
-
- Installs a callable with the given instance's _state which will fire off
- when any of the named attributes are accessed; their existing value is
- removed.
-
- If the list is None or blank, the entire instance is expired.
-
- """
- state.expire_attributes(attribute_names)
-
+_expire_state = attributes.InstanceState.expire_attributes
register_attribute = unitofwork.register_attribute
_sessions = weakref.WeakValueDictionary()
def object_session(instance):
"""Return the ``Session`` to which instance belongs, or None."""
+
return _state_session(attributes.instance_state(instance))
def _state_session(state):
class DefaultColumnLoader(LoaderStrategy):
- def _register_attribute(self, compare_function, copy_function, mutable_scalars, comparator_factory, callable_=None, proxy_property=None, active_history=False):
+ def _register_attribute(self, compare_function, copy_function, mutable_scalars,
+ comparator_factory, callable_=None, proxy_property=None, active_history=False):
self.logger.info("%s register managed attribute" % self)
attribute_ext = util.to_list(self.parent_property.extension) or []
mapper = object_mapper(instance)
return mapper.identity_key_from_instance(instance)
-class ExtensionCarrier(object):
+class ExtensionCarrier(dict):
"""Fronts an ordered collection of MapperExtension objects.
Bundles multiple MapperExtensions into a unified callable unit,
carrier.after_insert(...args...)
- Also includes a 'methods' dictionary accessor which allows for a quick
- check if a particular method is overridden on any contained
- MapperExtensions.
+ The dictionary interface provides containment for implemented
+ method names mapped to a callable which executes that method
+ for participating extensions.
"""
if not method.startswith('_'))
def __init__(self, extensions=None):
- self.methods = {}
self._extensions = []
for ext in extensions or ():
self.append(ext)
def _register(self, extension):
"""Register callable fronts for overridden interface methods."""
- for method in self.interface:
- if method in self.methods:
- continue
+
+ for method in self.interface.difference(self):
impl = getattr(extension, method, None)
if impl and impl is not getattr(MapperExtension, method):
- self.methods[method] = self._create_do(method)
+ self[method] = self._create_do(method)
def _create_do(self, method):
"""Return a closure that loops over impls of the named method."""
return ret
else:
return EXT_CONTINUE
- try:
- _do.__name__ = method.im_func.func_name
- except:
- pass
+ _do.__name__ = method
return _do
@staticmethod
def __getattr__(self, key):
"""Delegate MapperExtension methods to bundled fronts."""
+
if key not in self.interface:
raise AttributeError(key)
- return self.methods.get(key, self._pass)
+ return self.get(key, self._pass)
class ORMAdapter(sql_util.ColumnAdapter):
def __init__(self, entity, equivalents=None, chain_to=None):
return self.adapter.traverse(self.comparator.reverse_operate(op, *other, **kwargs))
def _orm_annotate(element, exclude=None):
+ """Deep copy the given ClauseElement, annotating each element with the "_orm_adapt" flag.
+
+ Elements within the exclude collection will be cloned but not annotated.
+
+ """
def clone(elem):
if exclude and elem in exclude:
elem = elem._clone()
class _ORMJoin(expression.Join):
-
+ """Extend Join to support ORM constructs as input."""
+
__visit_name__ = expression.Join.__visit_name__
def __init__(self, left, right, onclause=None, isouter=False):
return _ORMJoin(self, right, onclause, True)
def join(left, right, onclause=None, isouter=False):
+ """Produce an inner join between left and right clauses.
+
+ In addition to the interface provided by
+ sqlalchemy.sql.join(), left and right may be mapped
+ classes or AliasedClass instances. The onclause may be a
+ string name of a relation(), or a class-bound descriptor
+ representing a relation.
+
+ """
return _ORMJoin(left, right, onclause, isouter)
def outerjoin(left, right, onclause=None):
+ """Produce a left outer join between left and right clauses.
+
+ In addition to the interface provided by
+ sqlalchemy.sql.outerjoin(), left and right may be mapped
+ classes or AliasedClass instances. The onclause may be a
+ string name of a relation(), or a class-bound descriptor
+ representing a relation.
+
+ """
return _ORMJoin(left, right, onclause, True)
def with_parent(instance, prop):
def _entity_info(entity, compile=True):
+ """Return mapping information given a class, mapper, or AliasedClass.
+
+ Returns 3-tuple of: mapper, mapped selectable, boolean indicating if this
+ is an aliased() construct.
+
+ If the given entity is not a mapper, mapped class, or aliased construct,
+ returns None, the entity, False. This is typically used to allow
+ unmapped selectables through.
+
+ """
if isinstance(entity, AliasedClass):
return entity._AliasedClass__mapper, entity._AliasedClass__alias, True
elif _is_mapped_class(entity):
return None, entity, False
def _entity_descriptor(entity, key):
+ """Return attribute/property information given an entity and string name.
+
+ Returns a 2-tuple representing InstrumentedAttribute/MapperProperty.
+
+ """
if isinstance(entity, AliasedClass):
desc = getattr(entity, key)
return desc, desc.property
def _state_mapper(state):
return state.manager.mapper
-def object_mapper(object, raiseerror=True):
+def object_mapper(instance):
"""Given an object, return the primary Mapper associated with the object instance.
-
- object
- The object instance.
-
- raiseerror
- Defaults to True: raise an ``InvalidRequestError`` if no mapper can
- be located. If False, return None.
-
+
+ Raises UnmappedInstanceError if no mapping is configured.
+
"""
try:
- state = attributes.instance_state(object)
+ state = attributes.instance_state(instance)
+ if not state.manager.mapper:
+ raise exc.UnmappedInstanceError(instance)
+ return state.manager.mapper
except exc.NO_STATE:
- if not raiseerror:
- return None
- raise exc.UnmappedInstanceError(object)
- return class_mapper(
- type(object), compile=False, raiseerror=raiseerror)
+ raise exc.UnmappedInstanceError(instance)
-def class_mapper(class_, compile=True, raiseerror=True):
+def class_mapper(class_, compile=True):
"""Given a class (or an object), return the primary Mapper associated with the key.
- If no mapper can be located, raises ``InvalidRequestError``.
-
- """
+ Raises UnmappedClassError if no mapping is configured.
+ """
if not isinstance(class_, type):
class_ = type(class_)
try:
raise AttributeError
except exc.NO_STATE:
- if not raiseerror:
- return
raise exc.UnmappedClassError(class_)
+
if compile:
mapper = mapper.compile()
return mapper
return state_str(attributes.instance_state(instance))
def state_str(state):
- """Return a string describing an instance."""
+ """Return a string describing an instance via its InstanceState."""
+
if state is None:
return "None"
else:
- return state.class_.__name__ + "@" + hex(id(state.obj()))
+ return '<%s at 0x%x>' % (state.class_.__name__, id(state.obj()))
def attribute_str(instance, attribute):
return instance_str(instance) + "." + attribute
return CompoundSelect(keyword, *selects, **kwargs)
def _is_literal(element):
- return not isinstance(element, (ClauseElement, Operators))
+ global schema
+ if not schema:
+ from sqlalchemy import schema
+ return not isinstance(element, (ClauseElement, Operators, schema.SchemaItem))
def _from_objects(*elements, **kwargs):
return itertools.chain(*[element._get_from_objects(**kwargs) for element in elements])
import pickle
from sqlalchemy.orm import create_session, sessionmaker, attributes
from testlib import engines, sa, testing, config
-from testlib.sa import Table, Column, Integer, String
+from testlib.sa import Table, Column, Integer, String, Sequence
from testlib.sa.orm import mapper, relation, backref
from testlib.testing import eq_
from engine import _base as engine_base
finally:
c.close()
+ @testing.requires.sequences
+ def test_sequence_execute(self):
+ seq = Sequence("some_sequence")
+ seq.create(testing.db)
+ try:
+ sess = create_session(bind=testing.db)
+ eq_(sess.execute(seq), 1)
+ finally:
+ seq.drop(testing.db)
+
+
@testing.resolve_artifact_names
def test_expunge_cascade(self):
mapper(Address, addresses)
assert s.query(Address).one().id == a.id
assert s.query(User).first() is None
- @testing.resolve_artifact_names
- def test_identity_key_1(self):
- mapper(User, users)
- s = create_session()
- key = s.identity_key(User, 1)
- eq_(key, (User, (1,)))
- key = s.identity_key(User, ident=1)
- eq_(key, (User, (1,)))
-
- @testing.resolve_artifact_names
- def test_identity_key_2(self):
- mapper(User, users)
- s = create_session()
- u = User(name='u1')
- s.add(u)
- s.flush()
- key = s.identity_key(instance=u)
- eq_(key, (User, (u.id,)))
-
- @testing.resolve_artifact_names
- def test_identity_key_3(self):
- mapper(User, users)
- s = create_session()
- row = {users.c.id: 1, users.c.name: "Frank"}
- key = s.identity_key(User, row=row)
- eq_(key, (User, (1,)))
-
@testing.resolve_artifact_names
def test_extension(self):
mapper(User, users)
from sqlalchemy import MetaData
from sqlalchemy import Table
from sqlalchemy.orm import aliased
-from sqlalchemy.orm import mapper
+from sqlalchemy.orm import mapper, create_session
+from orm import _fixtures
+from testlib.testing import eq_
class ExtensionCarrierTest(TestBase):
def test_basic(self):
carrier = util.ExtensionCarrier()
- assert 'translate_row' not in carrier.methods
+ assert 'translate_row' not in carrier
assert carrier.translate_row() is interfaces.EXT_CONTINUE
- assert 'translate_row' not in carrier.methods
+ assert 'translate_row' not in carrier
self.assertRaises(AttributeError, lambda: carrier.snickysnack)
return self.marker
carrier.append(Partial('end'))
- assert 'translate_row' in carrier.methods
+ assert 'translate_row' in carrier
assert carrier.translate_row(None) == 'end'
carrier.push(Partial('front'))
assert carrier.translate_row(None) == 'front'
- assert 'populate_instance' not in carrier.methods
+ assert 'populate_instance' not in carrier
carrier.append(interfaces.MapperExtension)
- assert 'populate_instance' in carrier.methods
+ assert 'populate_instance' in carrier
assert carrier.interface
for m in carrier.interface:
assert_table(Point.left_of(p2), table)
assert_table(alias.left_of(p2), alias_table)
-
+class IdentityKeyTest(_fixtures.FixtureTest):
+ run_inserts = None
+
+ @testing.resolve_artifact_names
+ def test_identity_key_1(self):
+ mapper(User, users)
+
+ key = util.identity_key(User, 1)
+ eq_(key, (User, (1,)))
+ key = util.identity_key(User, ident=1)
+ eq_(key, (User, (1,)))
+
+ @testing.resolve_artifact_names
+ def test_identity_key_2(self):
+ mapper(User, users)
+ s = create_session()
+ u = User(name='u1')
+ s.add(u)
+ s.flush()
+ key = util.identity_key(instance=u)
+ eq_(key, (User, (u.id,)))
+
+ @testing.resolve_artifact_names
+ def test_identity_key_3(self):
+ mapper(User, users)
+
+ row = {users.c.id: 1, users.c.name: "Frank"}
+ key = util.identity_key(User, row=row)
+ eq_(key, (User, (1,)))
+
if __name__ == '__main__':
testenv.main()