From: Mike Bayer Date: Sun, 26 Jul 2009 18:58:54 +0000 (+0000) Subject: merged -r6172:6204 of trunk X-Git-Tag: rel_0_6_6~67 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1f1256526e9c82c56dfdee80f7080cb320abd528;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git merged -r6172:6204 of trunk --- diff --git a/CHANGES b/CHANGES index 78f426aaa4..0a4c34b3d5 100644 --- a/CHANGES +++ b/CHANGES @@ -16,6 +16,35 @@ CHANGES during a flush. This is currently to support many-to-many relations from concrete inheritance setups. Outside of that use case, YMMV. [ticket:1477] + + - Squeezed a few more unnecessary "lazy loads" out of + relation(). When a collection is mutated, many-to-one + backrefs on the other side will not fire off to load + the "old" value, unless "single_parent=True" is set. + A direct assignment of a many-to-one still loads + the "old" value in order to update backref collections + on that value, which may be present in the session + already, thus maintaining the 0.5 behavioral contract. + [ticket:1483] + + - Fixed bug whereby a load/refresh of joined table + inheritance attributes which were based on + column_property() or similar would fail to evaluate. + [ticket:1480] + + - Improved error message when query() is called with + a non-SQL /entity expression. [ticket:1476] + + - Using False or 0 as a polymorphic discriminator now + works on the base class as well as a subclass. + [ticket:1440] + + - Added enable_assertions(False) to Query which disables + the usual assertions for expected state - used + by Query subclasses to engineer custom state. + [ticket:1424]. See + http://www.sqlalchemy.org/trac/wiki/UsageRecipes/PreFilteredQuery + for an example. - sql - Fixed a bug in extract() introduced in 0.5.4 whereby @@ -23,6 +52,24 @@ CHANGES ClauseElement, causing various errors within more complex SQL transformations. + - Unary expressions such as DISTINCT propagate their + type handling to result sets, allowing conversions like + unicode and such to take place. [ticket:1420] + + - Fixed bug in Table and Column whereby passing empty + dict for "info" argument would raise an exception. + [ticket:1482] + +- ext + - The collection proxies produced by associationproxy are now + pickleable. A user-defined proxy_factory however + is still not pickleable unless it defines __getstate__ + and __setstate__. [ticket:1446] + + - Declarative will raise an informative exception if + __table_args__ is passed as a tuple with no dict argument. + Improved documentation. [ticket:1468] + 0.5.5 ======= - general diff --git a/doc/build/ormtutorial.rst b/doc/build/ormtutorial.rst index acdbba149e..c10d457f14 100644 --- a/doc/build/ormtutorial.rst +++ b/doc/build/ormtutorial.rst @@ -567,6 +567,54 @@ To use an entirely string-based statement, using ``from_statement()``; just ensu ['ed'] {stop}[] +Counting +-------- + +``Query`` includes a convenience method for counting called ``count()``: + +.. sourcecode:: python+sql + + {sql}>>> session.query(User).filter(User.name.like('%ed')).count() + SELECT count(1) AS count_1 + FROM users + WHERE users.name LIKE ? + ['%ed'] + {stop}2 + +The ``count()`` method is used to determine how many rows the SQL statement would return, and is mainly intended to return a simple count of a single type of entity, in this case ``User``. For more complicated sets of columns or entities where the "thing to be counted" needs to be indicated more specifically, ``count()`` is probably not what you want. Below, a query for individual columns does return the expected result: + +.. sourcecode:: python+sql + + {sql}>>> session.query(User.id, User.name).filter(User.name.like('%ed')).count() + SELECT count(1) AS count_1 + FROM (SELECT users.id AS users_id, users.name AS users_name + FROM users + WHERE users.name LIKE ?) AS anon_1 + ['%ed'] + {stop}2 + +...but if you look at the generated SQL, SQLAlchemy saw that we were placing individual column expressions and decided to wrap whatever it was we were doing in a subquery, so as to be assured that it returns the "number of rows". This defensive behavior is not really needed here and in other cases is not what we want at all, such as if we wanted a grouping of counts per name: + +.. sourcecode:: python+sql + + {sql}>>> session.query(User.name).group_by(User.name).count() + SELECT count(1) AS count_1 + FROM (SELECT users.name AS users_name + FROM users GROUP BY users.name) AS anon_1 + [] + {stop}4 + +We don't want the number ``4``, we wanted some rows back. So for detailed queries where you need to count something specific, use the ``func.count()`` function as a column expression: + +.. sourcecode:: python+sql + + >>> from sqlalchemy import func + {sql}>>> session.query(func.count(User.name), User.name).group_by(User.name).all() + SELECT count(users.name) AS count_1, users.name AS users_name + FROM users GROUP BY users.name + {stop}[] + [(1, u'ed'), (1, u'fred'), (1, u'mary'), (1, u'wendy')] + Building a Relation ==================== @@ -824,7 +872,7 @@ The ``Query`` is suitable for generating statements which can be used as subquer (SELECT user_id, count(*) AS address_count FROM addresses GROUP BY user_id) AS adr_count ON users.id=adr_count.user_id -Using the ``Query``, we build a statement like this from the inside out. The ``statement`` accessor returns a SQL expression representing the statement generated by a particular ``Query`` - this is an instance of a ``select()`` construct, which are described in :ref:`sql`:: +Using the ``Query``, we build a statement like this from the inside out. The ``statement`` accessor returns a SQL expression representing the statement generated by a particular ``Query`` - this is an instance of a ``select()`` construct, which are described in :ref:`sqlexpression_toplevel`:: >>> from sqlalchemy.sql import func >>> stmt = session.query(Address.user_id, func.count('*').label('address_count')).group_by(Address.user_id).subquery() diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 315142d8e0..e126fe638d 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -140,26 +140,14 @@ class AssociationProxy(object): return (orm.class_mapper(self.owning_class). get_property(self.target_collection)) + @property def target_class(self): """The class the proxy is attached to.""" return self._get_property().mapper.class_ - target_class = property(target_class) def _target_is_scalar(self): return not self._get_property().uselist - def _lazy_collection(self, weakobjref): - target = self.target_collection - del self - def lazy_collection(): - obj = weakobjref() - if obj is None: - raise exceptions.InvalidRequestError( - "stale association proxy, parent object has gone out of " - "scope") - return getattr(obj, target) - return lazy_collection - def __get__(self, obj, class_): if self.owning_class is None: self.owning_class = class_ and class_ or type(obj) @@ -181,10 +169,10 @@ class AssociationProxy(object): return proxy except AttributeError: pass - proxy = self._new(self._lazy_collection(weakref.ref(obj))) + proxy = self._new(_lazy_collection(obj, self.target_collection)) setattr(obj, self.key, (id(obj), proxy)) return proxy - + def __set__(self, obj, values): if self.owning_class is None: self.owning_class = type(obj) @@ -238,13 +226,13 @@ class AssociationProxy(object): getter, setter = self.getset_factory(self.collection_class, self) else: getter, setter = self._default_getset(self.collection_class) - + if self.collection_class is list: - return _AssociationList(lazy_collection, creator, getter, setter) + return _AssociationList(lazy_collection, creator, getter, setter, self) elif self.collection_class is dict: - return _AssociationDict(lazy_collection, creator, getter, setter) + return _AssociationDict(lazy_collection, creator, getter, setter, self) elif self.collection_class is set: - return _AssociationSet(lazy_collection, creator, getter, setter) + return _AssociationSet(lazy_collection, creator, getter, setter, self) else: raise exceptions.ArgumentError( 'could not guess which interface to use for ' @@ -252,6 +240,18 @@ class AssociationProxy(object): 'proxy_factory and proxy_bulk_set manually' % (self.collection_class.__name__, self.target_collection)) + def _inflate(self, proxy): + creator = self.creator and self.creator or self.target_class + + if self.getset_factory: + getter, setter = self.getset_factory(self.collection_class, self) + else: + getter, setter = self._default_getset(self.collection_class) + + proxy.creator = creator + proxy.getter = getter + proxy.setter = setter + def _set(self, proxy, values): if self.proxy_bulk_set: self.proxy_bulk_set(proxy, values) @@ -266,12 +266,32 @@ class AssociationProxy(object): 'no proxy_bulk_set supplied for custom ' 'collection_class implementation') +class _lazy_collection(object): + def __init__(self, obj, target): + self.ref = weakref.ref(obj) + self.target = target -class _AssociationList(object): - """Generic, converting, list-to-list proxy.""" - - def __init__(self, lazy_collection, creator, getter, setter): - """Constructs an _AssociationList. + def __call__(self): + obj = self.ref() + if obj is None: + raise exceptions.InvalidRequestError( + "stale association proxy, parent object has gone out of " + "scope") + return getattr(obj, self.target) + + def __getstate__(self): + return {'obj':self.ref(), 'target':self.target} + + def __setstate__(self, state): + self.ref = weakref.ref(state['obj']) + self.target = state['target'] + +class _AssociationCollection(object): + def __init__(self, lazy_collection, creator, getter, setter, parent): + """Constructs an _AssociationCollection. + + This will always be a subclass of either _AssociationList, + _AssociationSet, or _AssociationDict. lazy_collection A callable returning a list-based collection of entities (usually an @@ -296,9 +316,27 @@ class _AssociationList(object): self.creator = creator self.getter = getter self.setter = setter + self.parent = parent col = property(lambda self: self.lazy_collection()) + def __len__(self): + return len(self.col) + + def __nonzero__(self): + return bool(self.col) + + def __getstate__(self): + return {'parent':self.parent, 'lazy_collection':self.lazy_collection} + + def __setstate__(self, state): + self.parent = state['parent'] + self.lazy_collection = state['lazy_collection'] + self.parent._inflate(self) + +class _AssociationList(_AssociationCollection): + """Generic, converting, list-to-list proxy.""" + def _create(self, value): return self.creator(value) @@ -308,15 +346,6 @@ class _AssociationList(object): def _set(self, object, value): return self.setter(object, value) - def __len__(self): - return len(self.col) - - def __nonzero__(self): - if self.col: - return True - else: - return False - def __getitem__(self, index): return self._get(self.col[index]) @@ -494,39 +523,9 @@ class _AssociationList(object): _NotProvided = util.symbol('_NotProvided') -class _AssociationDict(object): +class _AssociationDict(_AssociationCollection): """Generic, converting, dict-to-dict proxy.""" - def __init__(self, lazy_collection, creator, getter, setter): - """Constructs an _AssociationDict. - - lazy_collection - A callable returning a dict-based collection of entities (usually an - object attribute managed by a SQLAlchemy relation()) - - creator - A function that creates new target entities. Given two parameters: - key and value. The assertion is assumed:: - - obj = creator(somekey, somevalue) - assert getter(somekey) == somevalue - - getter - A function. Given an associated object and a key, return the - 'value'. - - setter - A function. Given an associated object, a key and a value, store - that value on the object. - - """ - self.lazy_collection = lazy_collection - self.creator = creator - self.getter = getter - self.setter = setter - - col = property(lambda self: self.lazy_collection()) - def _create(self, key, value): return self.creator(key, value) @@ -536,15 +535,6 @@ class _AssociationDict(object): def _set(self, object, key, value): return self.setter(object, key, value) - def __len__(self): - return len(self.col) - - def __nonzero__(self): - if self.col: - return True - else: - return False - def __getitem__(self, key): return self._get(self.col[key]) @@ -669,38 +659,9 @@ class _AssociationDict(object): del func_name, func -class _AssociationSet(object): +class _AssociationSet(_AssociationCollection): """Generic, converting, set-to-set proxy.""" - def __init__(self, lazy_collection, creator, getter, setter): - """Constructs an _AssociationSet. - - collection - A callable returning a set-based collection of entities (usually an - object attribute managed by a SQLAlchemy relation()) - - creator - A function that creates new target entities. Given one parameter: - value. The assertion is assumed:: - - obj = creator(somevalue) - assert getter(obj) == somevalue - - getter - A function. Given an associated object, return the 'value'. - - setter - A function. Given an associated object and a value, store that - value on the object. - - """ - self.lazy_collection = lazy_collection - self.creator = creator - self.getter = getter - self.setter = setter - - col = property(lambda self: self.lazy_collection()) - def _create(self, value): return self.creator(value) diff --git a/lib/sqlalchemy/ext/declarative.py b/lib/sqlalchemy/ext/declarative.py index 43369311b3..c37211ac3d 100644 --- a/lib/sqlalchemy/ext/declarative.py +++ b/lib/sqlalchemy/ext/declarative.py @@ -214,29 +214,39 @@ ORM function:: Table Configuration =================== -As an alternative to ``__tablename__``, a direct :class:`~sqlalchemy.schema.Table` construct may be -used. The :class:`~sqlalchemy.schema.Column` objects, which in this case require their names, will be -added to the mapping just like a regular mapping to a table:: +Table arguments other than the name, metadata, and mapped Column arguments +are specified using the ``__table_args__`` class attribute. This attribute +accommodates both positional as well as keyword arguments that are normally +sent to the :class:`~sqlalchemy.schema.Table` constructor. The attribute can be specified +in one of two forms. One is as a dictionary:: class MyClass(Base): - __table__ = Table('my_table', Base.metadata, - Column('id', Integer, primary_key=True), - Column('name', String(50)) - ) + __tablename__ = 'sometable' + __table_args__ = {'mysql_engine':'InnoDB'} -Other table-based attributes include ``__table_args__``, which is -either a dictionary as in:: +The other, a tuple of the form ``(arg1, arg2, ..., {kwarg1:value, ...})``, which +allows positional arguments to be specified as well (usually constraints):: class MyClass(Base): __tablename__ = 'sometable' - __table_args__ = {'mysql_engine':'InnoDB'} - -or a dictionary-containing tuple in the form -``(arg1, arg2, ..., {kwarg1:value, ...})``, as in:: + __table_args__ = ( + ForeignKeyConstraint(['id'], ['remote_table.id']), + UniqueConstraint('foo'), + {'autoload':True} + ) + +Note that the dictionary is required in the tuple form even if empty. + +As an alternative to ``__tablename__``, a direct :class:`~sqlalchemy.schema.Table` +construct may be used. The :class:`~sqlalchemy.schema.Column` objects, which +in this case require their names, will be +added to the mapping just like a regular mapping to a table:: class MyClass(Base): - __tablename__ = 'sometable' - __table_args__ = (ForeignKeyConstraint(['id'], ['remote_table.id']), {'autoload':True}) + __table__ = Table('my_table', Base.metadata, + Column('id', Integer, primary_key=True), + Column('name', String(50)) + ) Mapper Configuration ==================== @@ -468,6 +478,11 @@ def _as_declarative(cls, classname, dict_): elif isinstance(table_args, tuple): args = table_args[0:-1] table_kw = table_args[-1] + if len(table_args) < 2 or not isinstance(table_kw, dict): + raise exceptions.ArgumentError( + "Tuple form of __table_args__ is " + "(arg1, arg2, arg3, ..., {'kw1':val1, 'kw2':val2, ...})" + ) else: args, table_kw = (), {} diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 09a05b56fe..f6947dbc11 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -262,7 +262,6 @@ class AttributeImpl(object): active_history indicates that get_history() should always return the "old" value, even if it means executing a lazy callable upon attribute change. - This flag is set to True if any extensions are present. parent_token Usually references the MapperProperty, used as a key for @@ -286,6 +285,10 @@ class AttributeImpl(object): else: self.is_equal = compare_function self.extensions = util.to_list(extension or []) + for e in self.extensions: + if e.active_history: + active_history = True + break self.active_history = active_history self.expire_missing = expire_missing @@ -383,12 +386,12 @@ class AttributeImpl(object): return self.initialize(state, dict_) def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF): - self.set(state, dict_, value, initiator) + self.set(state, dict_, value, initiator, passive=passive) def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF): - self.set(state, dict_, None, initiator) + self.set(state, dict_, None, initiator, passive=passive) - def set(self, state, dict_, value, initiator): + def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF): raise NotImplementedError() def get_committed_value(self, state, dict_, passive=PASSIVE_OFF): @@ -421,7 +424,7 @@ class ScalarAttributeImpl(AttributeImpl): def delete(self, state, dict_): # TODO: catch key errors, convert to attributeerror? - if self.active_history or self.extensions: + if self.active_history: old = self.get(state, dict_) else: old = dict_.get(self.key, NO_VALUE) @@ -436,11 +439,11 @@ class ScalarAttributeImpl(AttributeImpl): return History.from_attribute( self, state, dict_.get(self.key, NO_VALUE)) - def set(self, state, dict_, value, initiator): + def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF): if initiator is self: return - if self.active_history or self.extensions: + if self.active_history: old = self.get(state, dict_) else: old = dict_.get(self.key, NO_VALUE) @@ -511,7 +514,7 @@ class MutableScalarAttributeImpl(ScalarAttributeImpl): ScalarAttributeImpl.delete(self, state, dict_) state.mutable_dict.pop(self.key) - def set(self, state, dict_, value, initiator): + def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF): if initiator is self: return @@ -559,7 +562,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): else: return History.from_attribute(self, state, current) - def set(self, state, dict_, value, initiator): + def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF): """Set a value on the given InstanceState. `initiator` is the ``InstrumentedAttribute`` that initiated the @@ -569,9 +572,21 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): """ if initiator is self: return - - # may want to add options to allow the get() here to be passive - old = self.get(state, dict_) + + if self.active_history: + old = self.get(state, dict_) + else: + # this would be the "laziest" approach, + # however it breaks currently expected backref + # behavior + #old = dict_.get(self.key, None) + # instead, use the "passive" setting, which + # is only going to be PASSIVE_NOCALLABLES if it + # came from a backref + old = self.get(state, dict_, passive=passive) + if old is PASSIVE_NORESULT: + old = None + value = self.fire_replace_event(state, dict_, value, old, initiator) dict_[self.key] = value @@ -707,7 +722,7 @@ class CollectionAttributeImpl(AttributeImpl): else: collection.remove_with_event(value, initiator) - def set(self, state, dict_, value, initiator): + def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF): """Set a value on the given object. `initiator` is the ``InstrumentedAttribute`` that initiated the @@ -808,6 +823,9 @@ class GenericBackrefExtension(interfaces.AttributeExtension): are two objects which contain scalar references to each other. """ + + active_history = False + def __init__(self, key): self.key = key diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index 70243291dc..0bc7bab24e 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -100,7 +100,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl): dict_[self.key] = True return state.committed_state[self.key] - def set(self, state, dict_, value, initiator): + def set(self, state, dict_, value, initiator, passive=attributes.PASSIVE_OFF): if initiator is self: return diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index c590f4323f..eaafe5761a 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -783,6 +783,11 @@ class AttributeExtension(object): """ + active_history = True + """indicates that the set() method would like to receive the 'old' value, + even if it means firing lazy callables. + """ + def append(self, state, value, initiator): """Receive a collection append event. diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index d155f66d12..bb241b6aeb 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -244,7 +244,7 @@ class Mapper(object): else: self.mapped_table = self.local_table - if self.polymorphic_identity and not self.concrete: + if self.polymorphic_identity is not None and not self.concrete: self._identity_class = self.inherits._identity_class else: self._identity_class = self.class_ @@ -278,7 +278,7 @@ class Mapper(object): self._all_tables = set() self.base_mapper = self self.mapped_table = self.local_table - if self.polymorphic_identity: + if self.polymorphic_identity is not None: self.polymorphic_map[self.polymorphic_identity] = self self._identity_class = self.class_ @@ -1101,7 +1101,12 @@ class Mapper(object): """ props = self._props - tables = set(props[key].parent.local_table for key in attribute_names) + + tables = set(chain(* + (sqlutil.find_tables(props[key].columns[0], check_columns=True) + for key in attribute_names) + )) + if self.base_mapper.local_table in tables: return None @@ -1138,7 +1143,8 @@ class Mapper(object): return None cond = sql.and_(*allconds) - return sql.select(tables, cond, use_labels=True) + + return sql.select([props[key].columns[0] for key in attribute_names], cond, use_labels=True) def cascade_iterator(self, type_, state, halt_on=None): """Iterate each element and its mapper in an object graph, diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 0c4bbc6eb6..0ea67f6e53 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -54,40 +54,41 @@ def _generative(*assertions): return generate class Query(object): - """Encapsulates the object-fetching operations provided by Mappers.""" - + """ORM-level SQL construction object.""" + + _enable_eagerloads = True + _enable_assertions = True + _with_labels = False + _criterion = None + _yield_per = None + _lockmode = None + _order_by = False + _group_by = False + _having = None + _distinct = False + _offset = None + _limit = None + _statement = None + _joinpoint = None + _correlate = frozenset() + _populate_existing = False + _version_check = False + _autoflush = True + _current_path = () + _only_load_props = None + _refresh_state = None + _from_obj = () + _filter_aliases = None + _from_obj_alias = None + _currenttables = frozenset() + def __init__(self, entities, session=None): self.session = session self._with_options = [] - self._lockmode = None - self._order_by = False - self._group_by = False - self._distinct = False - self._offset = None - self._limit = None - self._statement = None self._params = {} - self._yield_per = None - self._criterion = None - self._correlate = set() - self._joinpoint = None - self._with_labels = False - self._enable_eagerloads = True - self.__joinable_tables = None - self._having = None - self._populate_existing = False - self._version_check = False - self._autoflush = True self._attributes = {} - self._current_path = () - self._only_load_props = None - self._refresh_state = None - self._from_obj = () self._polymorphic_adapters = {} - self._filter_aliases = None - self._from_obj_alias = None - self.__currenttables = set() self._set_entities(entities) def _set_entities(self, entities, entity_wrapper=None): @@ -97,9 +98,9 @@ class Query(object): for ent in util.to_list(entities): entity_wrapper(self, ent) - self.__setup_aliasizers(self._entities) + self._setup_aliasizers(self._entities) - def __setup_aliasizers(self, entities): + def _setup_aliasizers(self, entities): if hasattr(self, '_mapper_adapter_map'): # usually safe to share a single map, but copying to prevent # subtle leaks if end-user is reusing base query with arbitrary @@ -114,7 +115,8 @@ class Query(object): mapper, selectable, is_aliased_class = _entity_info(entity) if not is_aliased_class and mapper.with_polymorphic: with_polymorphic = mapper._with_polymorphic_mappers - self.__mapper_loads_polymorphically_with(mapper, sql_util.ColumnAdapter(selectable, mapper._equivalent_columns)) + self.__mapper_loads_polymorphically_with(mapper, + sql_util.ColumnAdapter(selectable, mapper._equivalent_columns)) adapter = None elif is_aliased_class: adapter = sql_util.ColumnAdapter(selectable, mapper._equivalent_columns) @@ -131,7 +133,7 @@ class Query(object): for m in m2.iterate_to_root(): self._polymorphic_adapters[m.mapped_table] = self._polymorphic_adapters[m.local_table] = adapter - def __set_select_from(self, from_obj): + def _set_select_from(self, from_obj): if isinstance(from_obj, expression._SelectBaseMixin): from_obj = from_obj.alias() @@ -142,7 +144,8 @@ class Query(object): self._from_obj_alias = sql_util.ColumnAdapter(from_obj, equivs) def _get_polymorphic_adapter(self, entity, selectable): - self.__mapper_loads_polymorphically_with(entity.mapper, sql_util.ColumnAdapter(selectable, entity.mapper._equivalent_columns)) + self.__mapper_loads_polymorphically_with(entity.mapper, + sql_util.ColumnAdapter(selectable, entity.mapper._equivalent_columns)) def _reset_polymorphic_adapter(self, mapper): for m2 in mapper._with_polymorphic_mappers: @@ -151,7 +154,7 @@ class Query(object): self._polymorphic_adapters.pop(m.mapped_table, None) self._polymorphic_adapters.pop(m.local_table, None) - def __reset_joinpoint(self): + def _reset_joinpoint(self): self._joinpoint = None self._filter_aliases = None @@ -210,9 +213,17 @@ class Query(object): return clause if getattr(self, '_disable_orm_filtering', not orm_only): - return visitors.replacement_traverse(clause, {'column_collections':False}, self.__replace_element(adapters)) + return visitors.replacement_traverse( + clause, + {'column_collections':False}, + self.__replace_element(adapters) + ) else: - return visitors.replacement_traverse(clause, {'column_collections':False}, self.__replace_orm_element(adapters)) + return visitors.replacement_traverse( + clause, + {'column_collections':False}, + self.__replace_orm_element(adapters) + ) def _entity_zero(self): return self._entities[0] @@ -243,12 +254,16 @@ class Query(object): def _only_mapper_zero(self, rationale=None): if len(self._entities) > 1: - raise sa_exc.InvalidRequestError(rationale or "This operation requires a Query against a single mapper.") + raise sa_exc.InvalidRequestError( + rationale or "This operation requires a Query against a single mapper." + ) return self._mapper_zero() def _only_entity_zero(self, rationale=None): if len(self._entities) > 1: - raise sa_exc.InvalidRequestError(rationale or "This operation requires a Query against a single mapper.") + raise sa_exc.InvalidRequestError( + rationale or "This operation requires a Query against a single mapper." + ) return self._entity_zero() def _generate_mapper_zero(self): @@ -264,7 +279,9 @@ class Query(object): equivs.update(ent.mapper._equivalent_columns) return equivs - def __no_criterion_condition(self, meth): + def _no_criterion_condition(self, meth): + if not self._enable_assertions: + return if self._criterion or self._statement or self._from_obj or \ self._limit is not None or self._offset is not None or \ self._group_by: @@ -273,28 +290,36 @@ class Query(object): self._from_obj = () self._statement = self._criterion = None self._order_by = self._group_by = self._distinct = False - self.__joined_tables = {} - def __no_clauseelement_condition(self, meth): + def _no_clauseelement_condition(self, meth): + if not self._enable_assertions: + return if self._order_by: raise sa_exc.InvalidRequestError("Query.%s() being called on a Query with existing criterion. " % meth) - self.__no_criterion_condition(meth) + self._no_criterion_condition(meth) - def __no_statement_condition(self, meth): + def _no_statement_condition(self, meth): + if not self._enable_assertions: + return if self._statement: raise sa_exc.InvalidRequestError( ("Query.%s() being called on a Query with an existing full " "statement - can't apply criterion.") % meth) - def __no_limit_offset(self, meth): + def _no_limit_offset(self, meth): + if not self._enable_assertions: + return if self._limit is not None or self._offset is not None: - # TODO: do we want from_self() to be implicit here ? i vote explicit for the time being - raise sa_exc.InvalidRequestError("Query.%s() being called on a Query which already has LIMIT or OFFSET applied. " - "To modify the row-limited results of a Query, call from_self() first. Otherwise, call %s() before limit() or offset() are applied." % (meth, meth) + raise sa_exc.InvalidRequestError( + "Query.%s() being called on a Query which already has LIMIT or OFFSET applied. " + "To modify the row-limited results of a Query, call from_self() first. " + "Otherwise, call %s() before limit() or offset() are applied." % (meth, meth) ) - - def __get_options(self, populate_existing=None, version_check=None, only_load_props=None, refresh_state=None): + def _get_options(self, populate_existing=None, + version_check=None, + only_load_props=None, + refresh_state=None): if populate_existing: self._populate_existing = populate_existing if version_check: @@ -315,7 +340,8 @@ class Query(object): def statement(self): """The full SELECT statement represented by this Query.""" - return self._compile_context(labels=self._with_labels).statement._annotate({'_halt_adapt': True}) + return self._compile_context(labels=self._with_labels).\ + statement._annotate({'_halt_adapt': True}) def subquery(self): """return the full SELECT statement represented by this Query, embedded within an Alias. @@ -358,7 +384,29 @@ class Query(object): """ self._with_labels = True - + + @_generative() + def enable_assertions(self, value): + """Control whether assertions are generated. + + When set to False, the returned Query will + not assert its state before certain operations, + including that LIMIT/OFFSET has not been applied + when filter() is called, no criterion exists + when get() is called, and no "from_statement()" + exists when filter()/order_by()/group_by() etc. + is called. This more permissive mode is used by + custom Query subclasses to specify criterion or + other modifiers outside of the usual usage patterns. + + Care should be taken to ensure that the usage + pattern is even possible. A statement applied + by from_statement() will override any criterion + set by filter() or order_by(), for example. + + """ + self._enable_assertions = value + @property def whereclause(self): """The WHERE criterion for this Query.""" @@ -375,7 +423,7 @@ class Query(object): """ self._current_path = path - @_generative(__no_clauseelement_condition) + @_generative(_no_clauseelement_condition) def with_polymorphic(self, cls_or_mappers, selectable=None, discriminator=None): """Load columns for descendant mappers of this Query's mapper. @@ -438,7 +486,9 @@ class Query(object): if hasattr(ident, '__composite_values__'): ident = ident.__composite_values__() - key = self._only_mapper_zero("get() can only be used against a single mapped class.").identity_key_from_primary_key(ident) + key = self._only_mapper_zero( + "get() can only be used against a single mapped class." + ).identity_key_from_primary_key(ident) return self._get(key, ident) @classmethod @@ -526,7 +576,11 @@ class Query(object): if isinstance(prop, properties.PropertyLoader) and prop.mapper is self._mapper_zero(): break else: - raise sa_exc.InvalidRequestError("Could not locate a property which relates instances of class '%s' to instances of class '%s'" % (self._mapper_zero().class_.__name__, instance.__class__.__name__)) + raise sa_exc.InvalidRequestError( + "Could not locate a property which relates instances " + "of class '%s' to instances of class '%s'" % + (self._mapper_zero().class_.__name__, instance.__class__.__name__) + ) else: prop = mapper.get_property(property, resolve_synonyms=True) return self.filter(prop.compare(operators.eq, instance, value_is_parent=True)) @@ -540,7 +594,7 @@ class Query(object): self._entities = list(self._entities) m = _MapperEntity(self, entity) - self.__setup_aliasizers([m]) + self._setup_aliasizers([m]) def from_self(self, *entities): """return a Query that selects from this Query's SELECT statement. @@ -562,7 +616,7 @@ class Query(object): self._statement = self._criterion = None self._order_by = self._group_by = self._distinct = False self._limit = self._offset = None - self.__set_select_from(fromclause) + self._set_select_from(fromclause) def values(self, *columns): """Return an iterator yielding result tuples corresponding to the given list of columns""" @@ -596,20 +650,20 @@ class Query(object): _ColumnEntity(self, column) # _ColumnEntity may add many entities if the # given arg is a FROM clause - self.__setup_aliasizers(self._entities[l:]) + self._setup_aliasizers(self._entities[l:]) def options(self, *args): """Return a new Query object, applying the given list of MapperOptions. """ - return self.__options(False, *args) + return self._options(False, *args) def _conditional_options(self, *args): - return self.__options(True, *args) + return self._options(True, *args) @_generative() - def __options(self, conditional, *args): + def _options(self, conditional, *args): # most MapperOptions write to the '_attributes' dictionary, # so copy that as well self._attributes = self._attributes.copy() @@ -645,7 +699,7 @@ class Query(object): self._params = self._params.copy() self._params.update(kwargs) - @_generative(__no_statement_condition, __no_limit_offset) + @_generative(_no_statement_condition, _no_limit_offset) def filter(self, criterion): """apply the given filtering criterion to the query and return the newly resulting ``Query`` @@ -674,7 +728,7 @@ class Query(object): return self.filter(sql.and_(*clauses)) - @_generative(__no_statement_condition, __no_limit_offset) + @_generative(_no_statement_condition, _no_limit_offset) @util.accepts_a_list_as_starargs(list_deprecation='pending') def order_by(self, *criterion): """apply one or more ORDER BY criterion to the query and return the newly resulting ``Query``""" @@ -689,7 +743,7 @@ class Query(object): else: self._order_by = self._order_by + criterion - @_generative(__no_statement_condition, __no_limit_offset) + @_generative(_no_statement_condition, _no_limit_offset) @util.accepts_a_list_as_starargs(list_deprecation='pending') def group_by(self, *criterion): """apply one or more GROUP BY criterion to the query and return the newly resulting ``Query``""" @@ -703,7 +757,7 @@ class Query(object): else: self._group_by = self._group_by + criterion - @_generative(__no_statement_condition, __no_limit_offset) + @_generative(_no_statement_condition, _no_limit_offset) def having(self, criterion): """apply a HAVING criterion to the query and return the newly resulting ``Query``.""" @@ -871,7 +925,7 @@ class Query(object): aliased, from_joinpoint = kwargs.pop('aliased', False), kwargs.pop('from_joinpoint', False) if kwargs: raise TypeError("unknown arguments: %s" % ','.join(kwargs.iterkeys())) - return self.__join(props, outerjoin=False, create_aliases=aliased, from_joinpoint=from_joinpoint) + return self._join(props, outerjoin=False, create_aliases=aliased, from_joinpoint=from_joinpoint) @util.accepts_a_list_as_starargs(list_deprecation='pending') def outerjoin(self, *props, **kwargs): @@ -884,19 +938,19 @@ class Query(object): aliased, from_joinpoint = kwargs.pop('aliased', False), kwargs.pop('from_joinpoint', False) if kwargs: raise TypeError("unknown arguments: %s" % ','.join(kwargs.iterkeys())) - return self.__join(props, outerjoin=True, create_aliases=aliased, from_joinpoint=from_joinpoint) + return self._join(props, outerjoin=True, create_aliases=aliased, from_joinpoint=from_joinpoint) - @_generative(__no_statement_condition, __no_limit_offset) - def __join(self, keys, outerjoin, create_aliases, from_joinpoint): + @_generative(_no_statement_condition, _no_limit_offset) + def _join(self, keys, outerjoin, create_aliases, from_joinpoint): # copy collections that may mutate so they do not affect # the copied-from query. - self.__currenttables = set(self.__currenttables) + self._currenttables = set(self._currenttables) self._polymorphic_adapters = self._polymorphic_adapters.copy() # start from the beginning unless from_joinpoint is set. if not from_joinpoint: - self.__reset_joinpoint() + self._reset_joinpoint() clause = replace_clause_index = None @@ -1031,11 +1085,11 @@ class Query(object): elif prop: # for joins across plain relation()s, try not to specify the - # same joins twice. the __currenttables collection tracks + # same joins twice. the _currenttables collection tracks # what plain mapped tables we've joined to already. - if prop.table in self.__currenttables: - if prop.secondary is not None and prop.secondary not in self.__currenttables: + if prop.table in self._currenttables: + if prop.secondary is not None and prop.secondary not in self._currenttables: # TODO: this check is not strong enough for different paths to the same endpoint which # does not use secondary tables raise sa_exc.InvalidRequestError("Can't join to property '%s'; a path to this " @@ -1044,8 +1098,8 @@ class Query(object): continue if prop.secondary: - self.__currenttables.add(prop.secondary) - self.__currenttables.add(prop.table) + self._currenttables.add(prop.secondary) + self._currenttables.add(prop.table) if of_type: right_entity = of_type @@ -1101,7 +1155,7 @@ class Query(object): # future joins with from_joinpoint=True join from our established right_entity. self._joinpoint = right_entity - @_generative(__no_statement_condition) + @_generative(_no_statement_condition) def reset_joinpoint(self): """return a new Query reset the 'joinpoint' of this Query reset back to the starting mapper. Subsequent generative calls will @@ -1111,9 +1165,9 @@ class Query(object): the root. """ - self.__reset_joinpoint() + self._reset_joinpoint() - @_generative(__no_clauseelement_condition) + @_generative(_no_clauseelement_condition) def select_from(self, from_obj): """Set the `from_obj` parameter of the query and return the newly resulting ``Query``. This replaces the table which this Query selects @@ -1134,7 +1188,7 @@ class Query(object): from_obj = from_obj[-1] if not isinstance(from_obj, expression.FromClause): raise sa_exc.ArgumentError("select_from() accepts FromClause objects only.") - self.__set_select_from(from_obj) + self._set_select_from(from_obj) def __getitem__(self, item): if isinstance(item, slice): @@ -1157,7 +1211,7 @@ class Query(object): else: return list(self[item:item+1])[0] - @_generative(__no_statement_condition) + @_generative(_no_statement_condition) def slice(self, start, stop): """apply LIMIT/OFFSET to the ``Query`` based on a range and return the newly resulting ``Query``.""" if start is not None and stop is not None: @@ -1168,7 +1222,7 @@ class Query(object): elif start is not None and stop is None: self._offset = (self._offset or 0) + start - @_generative(__no_statement_condition) + @_generative(_no_statement_condition) def limit(self, limit): """Apply a ``LIMIT`` to the query and return the newly resulting @@ -1177,7 +1231,7 @@ class Query(object): """ self._limit = limit - @_generative(__no_statement_condition) + @_generative(_no_statement_condition) def offset(self, offset): """Apply an ``OFFSET`` to the query and return the newly resulting ``Query``. @@ -1185,7 +1239,7 @@ class Query(object): """ self._offset = offset - @_generative(__no_statement_condition) + @_generative(_no_statement_condition) def distinct(self): """Apply a ``DISTINCT`` to the query and return the newly resulting ``Query``. @@ -1201,7 +1255,7 @@ class Query(object): """ return list(self) - @_generative(__no_clauseelement_condition) + @_generative(_no_clauseelement_condition) def from_statement(self, statement): """Execute the given SELECT statement and return results. @@ -1402,7 +1456,7 @@ class Query(object): if refresh_state is None: q = self._clone() - q.__no_criterion_condition("get") + q._no_criterion_condition("get") else: q = self._clone() @@ -1424,7 +1478,7 @@ class Query(object): if lockmode is not None: q._lockmode = lockmode - q.__get_options( + q._get_options( populate_existing=bool(refresh_state), version_check=(lockmode is not None), only_load_props=only_load_props, @@ -1454,18 +1508,24 @@ class Query(object): kwargs.get('distinct', False)) def count(self): - """Apply this query's criterion to a SELECT COUNT statement. - - If column expressions or LIMIT/OFFSET/DISTINCT are present, - the query "SELECT count(1) FROM (SELECT ...)" is issued, - so that the result matches the total number of rows - this query would return. For mapped entities, - the primary key columns of each is written to the - columns clause of the nested SELECT statement. - - For a Query which is only against mapped entities, - a simpler "SELECT count(1) FROM table1, table2, ... - WHERE criterion" is issued. + """Return a count of rows this Query would return. + + For simple entity queries, count() issues + a SELECT COUNT, and will specifically count the primary + key column of the first entity only. If the query uses + LIMIT, OFFSET, or DISTINCT, count() will wrap the statement + generated by this Query in a subquery, from which a SELECT COUNT + is issued, so that the contract of "how many rows + would be returned?" is honored. + + For queries that request specific columns or expressions, + count() again makes no assumptions about those expressions + and will wrap everything in a subquery. Therefore, + ``Query.count()`` is usually not what you want in this case. + To count specific columns, often in conjunction with + GROUP BY, use ``func.count()`` as an individual column expression + instead of ``Query.count()``. See the ORM tutorial + for an example. """ should_nest = [self._should_nest_selectable] @@ -2027,7 +2087,9 @@ class _ColumnEntity(_QueryEntity): return if not isinstance(column, sql.ColumnElement): - raise sa_exc.InvalidRequestError("Invalid column expression '%r'" % column) + raise sa_exc.InvalidRequestError( + "SQL expression, column, or mapped entity expected - got '%r'" % column + ) # if the Column is unnamed, give it a # label() so that mutable column expressions diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index ea5aae645e..d3d653de4f 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -638,9 +638,9 @@ class Session(object): If no transaction is in progress, this method is a pass-through. This method rolls back the current transaction or nested transaction - regardless of subtransactions being in effect. All subtrasactions up + regardless of subtransactions being in effect. All subtransactions up to the first real transaction are closed. Subtransactions occur when - begin() is called mulitple times. + begin() is called multiple times. """ if self.transaction is None: diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index d2ab214666..902658a0e4 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -33,7 +33,7 @@ def _register_attribute(strategy, mapper, useobject, prop = strategy.parent_property attribute_ext = util.to_list(prop.extension) or [] - + if useobject and prop.single_parent: attribute_ext.append(_SingleParentValidator(prop)) @@ -370,13 +370,16 @@ class LazyLoader(AbstractRelationLoader): def init_class_attribute(self, mapper): self.is_class_level = True - + # MANYTOONE currently only needs the "old" value for delete-orphan + # cascades. the required _SingleParentValidator will enable active_history + # in that case. otherwise we don't need the "old" value during backref operations. _register_attribute(self, mapper, useobject=True, callable_=self._class_level_loader, uselist = self.parent_property.uselist, typecallable = self.parent_property.collection_class, + active_history = self.parent_property.direction is not interfaces.MANYTOONE, ) def lazy_clause(self, state, reverse_direction=False, alias_secondary=False, adapt_source=None): diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index e2bc3fde53..bca6b4f463 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -33,7 +33,9 @@ class UOWEventHandler(interfaces.AttributeExtension): """An event handler added to all relation attributes which handles session cascade operations. """ - + + active_history = False + def __init__(self, key): self.key = key diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 252fa8407f..346bf884af 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -65,14 +65,9 @@ class SchemaItem(visitors.Visitable): def __repr__(self): return "%s()" % self.__class__.__name__ - @property + @util.memoized_property def info(self): - try: - return self._info - except AttributeError: - self._info = {} - return self._info - + return {} def _get_table_key(name, schema): if schema is None: @@ -223,8 +218,8 @@ class Table(SchemaItem, expression.TableClause): self.quote = kwargs.pop('quote', None) self.quote_schema = kwargs.pop('quote_schema', None) - if kwargs.get('info'): - self._info = kwargs.pop('info') + if 'info' in kwargs: + self.info = kwargs.pop('info') self._prefixes = kwargs.pop('prefixes', []) @@ -263,7 +258,7 @@ class Table(SchemaItem, expression.TableClause): setattr(self, key, kwargs.pop(key)) if 'info' in kwargs: - self._info = kwargs.pop('info') + self.info = kwargs.pop('info') self._extra_kwargs(**kwargs) self._init_items(*args) @@ -614,8 +609,9 @@ class Column(SchemaItem, expression.ColumnClause): util.set_creation_order(self) - if kwargs.get('info'): - self._info = kwargs.pop('info') + if 'info' in kwargs: + self.info = kwargs.pop('info') + if kwargs: raise exc.ArgumentError( "Unknown arguments passed to Column: " + repr(kwargs.keys())) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8899486546..66dc84f194 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -412,8 +412,8 @@ class SQLCompiler(engine.Compiled): else: return text - def visit_unary(self, unary, **kwargs): - s = self.process(unary.element) + def visit_unary(self, unary, **kw): + s = self.process(unary.element, **kw) if unary.operator: s = OPERATORS[unary.operator] + s if unary.modifier: diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 6c116e5bc0..2da5184e69 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -407,8 +407,8 @@ def not_(clause): def distinct(expr): """Return a ``DISTINCT`` clause.""" - - return _UnaryExpression(expr, operator=operators.distinct_op) + expr = _literal_as_binds(expr) + return _UnaryExpression(expr, operator=operators.distinct_op, type_=expr.type) def between(ctest, cleft, cright): """Return a ``BETWEEN`` predicate clause. @@ -1568,8 +1568,7 @@ class _CompareMixin(ColumnOperators): def distinct(self): """Produce a DISTINCT clause, i.e. ``DISTINCT ``""" - - return _UnaryExpression(self, operator=operators.distinct_op) + return _UnaryExpression(self, operator=operators.distinct_op, type_=self.type) def between(self, cleft, cright): """Produce a BETWEEN clause, i.e. `` BETWEEN AND ``""" diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index f1f329b5e2..ac95c3a209 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -53,24 +53,21 @@ def find_tables(clause, check_columns=False, include_aliases=False, include_join tables = [] _visitors = {} - def visit_something(elem): - tables.append(elem) - if include_selects: - _visitors['select'] = _visitors['compound_select'] = visit_something + _visitors['select'] = _visitors['compound_select'] = tables.append if include_joins: - _visitors['join'] = visit_something + _visitors['join'] = tables.append if include_aliases: - _visitors['alias'] = visit_something + _visitors['alias'] = tables.append if check_columns: def visit_column(column): tables.append(column.table) _visitors['column'] = visit_column - _visitors['table'] = visit_something + _visitors['table'] = tables.append visitors.traverse(clause, {'column_collections':False}, _visitors) return tables diff --git a/test/dialect/test_mssql.py b/test/dialect/test_mssql.py index 2537eb695e..d8a541abf0 100644 --- a/test/dialect/test_mssql.py +++ b/test/dialect/test_mssql.py @@ -1178,6 +1178,6 @@ class BinaryTest(TestBase, AssertsExecutionResults): def load_stream(self, name, len=3000): f = os.path.join(os.path.dirname(__file__), "..", name) - return file(f).read(len) + return file(f, 'rb').read(len) diff --git a/test/engine/test_metadata.py b/test/engine/test_metadata.py index 9f753039a5..784a7b9ce6 100644 --- a/test/engine/test_metadata.py +++ b/test/engine/test_metadata.py @@ -162,3 +162,29 @@ class TableOptionsTest(TestBase, AssertsCompiledSQL): "CREATE VIRTUAL TABLE temporary_table_2 (col1 INTEGER)" ) + def test_table_info(self): + metadata = MetaData() + t1 = Table('foo', metadata, info={'x':'y'}) + t2 = Table('bar', metadata, info={}) + t3 = Table('bat', metadata) + assert t1.info == {'x':'y'} + assert t2.info == {} + assert t3.info == {} + for t in (t1, t2, t3): + t.info['bar'] = 'zip' + assert t.info['bar'] == 'zip' + +class ColumnOptionsTest(TestBase): + def test_column_info(self): + + c1 = Column('foo', info={'x':'y'}) + c2 = Column('bar', info={}) + c3 = Column('bat') + assert c1.info == {'x':'y'} + assert c2.info == {} + assert c3.info == {} + + for c in (c1, c2, c3): + c.info['bar'] = 'zip' + assert c.info['bar'] == 'zip' + diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py index 7ef749c965..4a5775218d 100644 --- a/test/ext/test_associationproxy.py +++ b/test/ext/test_associationproxy.py @@ -1,4 +1,7 @@ -from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message +from sqlalchemy.test.testing import eq_, assert_raises +import copy +import pickle + from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.orm.collections import collection @@ -15,12 +18,15 @@ class DictCollection(dict): def remove(self, obj): del self[obj.foo] + class SetCollection(set): pass + class ListCollection(list): pass + class ObjectCollection(object): def __init__(self): self.values = list() @@ -33,6 +39,21 @@ class ObjectCollection(object): def __iter__(self): return iter(self.values) + +class Parent(object): + kids = association_proxy('children', 'name') + def __init__(self, name): + self.name = name + +class Child(object): + def __init__(self, name): + self.name = name + +class KVChild(object): + def __init__(self, name, value): + self.name = name + self.value = value + class _CollectionOperations(TestBase): def setup(self): collection_class = self.collection_class @@ -837,29 +858,23 @@ class ReconstitutionTest(TestBase): metadata.create_all() parents.insert().execute(name='p1') - class Parent(object): - kids = association_proxy('children', 'name') - def __init__(self, name): - self.name = name - - class Child(object): - def __init__(self, name): - self.name = name - - mapper(Parent, parents, properties=dict(children=relation(Child))) - mapper(Child, children) self.metadata = metadata - self.Parent = Parent - + self.parents = parents + self.children = children + def teardown(self): self.metadata.drop_all() + clear_mappers() def test_weak_identity_map(self): + mapper(Parent, self.parents, properties=dict(children=relation(Child))) + mapper(Child, self.children) + session = create_session(weak_identity_map=True) def add_child(parent_name, child_name): - parent = (session.query(self.Parent). + parent = (session.query(Parent). filter_by(name=parent_name)).one() parent.kids.append(child_name) @@ -869,12 +884,14 @@ class ReconstitutionTest(TestBase): add_child('p1', 'c2') session.flush() - p = session.query(self.Parent).filter_by(name='p1').one() + p = session.query(Parent).filter_by(name='p1').one() assert set(p.kids) == set(['c1', 'c2']), p.kids def test_copy(self): - import copy - p = self.Parent('p1') + mapper(Parent, self.parents, properties=dict(children=relation(Child))) + mapper(Child, self.children) + + p = Parent('p1') p.kids.extend(['c1', 'c2']) p_copy = copy.copy(p) del p @@ -882,4 +899,51 @@ class ReconstitutionTest(TestBase): assert set(p_copy.kids) == set(['c1', 'c2']), p.kids + def test_pickle_list(self): + mapper(Parent, self.parents, properties=dict(children=relation(Child))) + mapper(Child, self.children) + + p = Parent('p1') + p.kids.extend(['c1', 'c2']) + + r1 = pickle.loads(pickle.dumps(p)) + assert r1.kids == ['c1', 'c2'] + r2 = pickle.loads(pickle.dumps(p.kids)) + assert r2 == ['c1', 'c2'] + + def test_pickle_set(self): + mapper(Parent, self.parents, properties=dict(children=relation(Child, collection_class=set))) + mapper(Child, self.children) + + p = Parent('p1') + p.kids.update(['c1', 'c2']) + + r1 = pickle.loads(pickle.dumps(p)) + assert r1.kids == set(['c1', 'c2']) + + r2 = pickle.loads(pickle.dumps(p.kids)) + assert r2 == set(['c1', 'c2']) + + def test_pickle_dict(self): + mapper(Parent, self.parents, properties=dict( + children=relation(KVChild, collection_class=collections.mapped_collection(PickleKeyFunc('name'))) + )) + mapper(KVChild, self.children) + + p = Parent('p1') + p.kids.update({'c1':'v1', 'c2':'v2'}) + assert p.kids == {'c1':'c1', 'c2':'c2'} + + r1 = pickle.loads(pickle.dumps(p)) + assert r1.kids == {'c1':'c1', 'c2':'c2'} + + r2 = pickle.loads(pickle.dumps(p.kids)) + assert r2 == {'c1':'c1', 'c2':'c2'} + +class PickleKeyFunc(object): + def __init__(self, name): + self.name = name + + def __call__(self, obj): + return getattr(obj, self.name) \ No newline at end of file diff --git a/test/ext/test_declarative.py b/test/ext/test_declarative.py index 6bf709dfc7..9ca8356918 100644 --- a/test/ext/test_declarative.py +++ b/test/ext/test_declarative.py @@ -449,6 +449,15 @@ class DeclarativeTest(DeclarativeTestBase): define) def test_table_args(self): + + def err(): + class Foo(Base): + __tablename__ = 'foo' + __table_args__ = (ForeignKeyConstraint(['id'], ['foo.id']),) + id = Column('id', Integer, primary_key=True) + + assert_raises_message(sa.exc.ArgumentError, "Tuple form of __table_args__ is ", err) + class Foo(Base): __tablename__ = 'foo' __table_args__ = {'mysql_engine':'InnoDB'} diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index 435f26cbae..e9cd6093d2 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -74,20 +74,33 @@ class FalseDiscriminatorTest(_base.MappedTest): global t1 t1 = Table('t1', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), - Column('type', Integer, nullable=False)) + Column('type', Boolean, nullable=False)) - def test_false_discriminator(self): + def test_false_on_sub(self): class Foo(object):pass class Bar(Foo):pass - mapper(Foo, t1, polymorphic_on=t1.c.type, polymorphic_identity=1) - mapper(Bar, inherits=Foo, polymorphic_identity=0) + mapper(Foo, t1, polymorphic_on=t1.c.type, polymorphic_identity=True) + mapper(Bar, inherits=Foo, polymorphic_identity=False) sess = create_session() - f1 = Bar() - sess.add(f1) + b1 = Bar() + sess.add(b1) sess.flush() - assert f1.type == 0 + assert b1.type is False sess.expunge_all() assert isinstance(sess.query(Foo).one(), Bar) + + def test_false_on_base(self): + class Ding(object):pass + class Bat(Ding):pass + mapper(Ding, t1, polymorphic_on=t1.c.type, polymorphic_identity=False) + mapper(Bat, inherits=Ding, polymorphic_identity=True) + sess = create_session() + d1 = Ding() + sess.add(d1) + sess.flush() + assert d1.type is False + sess.expunge_all() + assert sess.query(Ding).one() is not None class PolymorphicSynonymTest(_base.MappedTest): @classmethod @@ -900,10 +913,8 @@ class OverrideColKeyTest(_base.MappedTest): assert sess.query(Sub).get(s1.base_id).data == "this is base" class OptimizedLoadTest(_base.MappedTest): - """test that the 'optimized load' routine doesn't crash when - a column in the join condition is not available. + """tests for the "optimized load" routine.""" - """ @classmethod def define_tables(cls, metadata): global base, sub @@ -918,7 +929,10 @@ class OptimizedLoadTest(_base.MappedTest): ) def test_optimized_passes(self): - class Base(object): + """"test that the 'optimized load' routine doesn't crash when + a column in the join condition is not available.""" + + class Base(_base.BasicEntity): pass class Sub(Base): pass @@ -928,21 +942,66 @@ class OptimizedLoadTest(_base.MappedTest): # redefine Sub's "id" to favor the "id" col in the subtable. # "id" is also part of the primary join condition mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', properties={'id':sub.c.id}) - sess = create_session() - s1 = Sub() - s1.data = 's1data' - s1.sub = 's1sub' + sess = sessionmaker()() + s1 = Sub(data='s1data', sub='s1sub') sess.add(s1) - sess.flush() + sess.commit() sess.expunge_all() # load s1 via Base. s1.id won't populate since it's relative to # the "sub" table. The optimized load kicks in and tries to # generate on the primary join, but cannot since "id" is itself unloaded. # the optimized load needs to return "None" so regular full-row loading proceeds - s1 = sess.query(Base).get(s1.id) + s1 = sess.query(Base).first() assert s1.sub == 's1sub' + def test_column_expression(self): + class Base(_base.BasicEntity): + pass + class Sub(Base): + pass + mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base') + mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', properties={ + 'concat':column_property(sub.c.sub + "|" + sub.c.sub) + }) + sess = sessionmaker()() + s1 = Sub(data='s1data', sub='s1sub') + sess.add(s1) + sess.commit() + sess.expunge_all() + s1 = sess.query(Base).first() + assert s1.concat == 's1sub|s1sub' + + def test_column_expression_joined(self): + class Base(_base.ComparableEntity): + pass + class Sub(Base): + pass + mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base') + mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', properties={ + 'concat':column_property(base.c.data + "|" + sub.c.sub) + }) + sess = sessionmaker()() + s1 = Sub(data='s1data', sub='s1sub') + s2 = Sub(data='s2data', sub='s2sub') + s3 = Sub(data='s3data', sub='s3sub') + sess.add_all([s1, s2, s3]) + sess.commit() + sess.expunge_all() + # query a bunch of rows to ensure there's no cartesian + # product against "base" occurring, it is in fact + # detecting that "base" needs to be in the join + # criterion + eq_( + sess.query(Base).order_by(Base.id).all(), + [ + Sub(data='s1data', sub='s1sub', concat='s1data|s1sub'), + Sub(data='s2data', sub='s2sub', concat='s2data|s2sub'), + Sub(data='s3data', sub='s3sub', concat='s3data|s3sub') + ] + ) + + class PKDiscriminatorTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): diff --git a/test/orm/test_backref_mutations.py b/test/orm/test_backref_mutations.py new file mode 100644 index 0000000000..1ecf0275a9 --- /dev/null +++ b/test/orm/test_backref_mutations.py @@ -0,0 +1,474 @@ +""" +a series of tests which assert the behavior of moving objects between collections +and scalar attributes resulting in the expected state w.r.t. backrefs, add/remove +events, etc. + +there's a particular focus on collections that have "uselist=False", since in these +cases the re-assignment of an attribute means the previous owner needs an +UPDATE in the database. + +""" + +from sqlalchemy.test.testing import assert_raises, assert_raises_message +from sqlalchemy import Integer, String, ForeignKey, Sequence, exc as sa_exc +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, class_mapper, backref, sessionmaker +from sqlalchemy.orm import attributes, exc as orm_exc +from sqlalchemy.test import testing +from sqlalchemy.test.testing import eq_ +from test.orm import _base, _fixtures + +class O2MCollectionTest(_fixtures.FixtureTest): + run_inserts = None + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(Address, addresses) + mapper(User, users, properties = dict( + addresses = relation(Address, backref="user"), + )) + + @testing.resolve_artifact_names + def test_collection_move_hitslazy(self): + sess = sessionmaker()() + a1 = Address(email_address="address1") + a2 = Address(email_address="address2") + a3 = Address(email_address="address3") + u1= User(name='jack', addresses=[a1, a2, a3]) + u2= User(name='ed') + sess.add_all([u1, a1, a2, a3]) + sess.commit() + + #u1.addresses + + def go(): + u2.addresses.append(a1) + u2.addresses.append(a2) + u2.addresses.append(a3) + self.assert_sql_count(testing.db, go, 0) + + @testing.resolve_artifact_names + def test_collection_move_preloaded(self): + sess = sessionmaker()() + a1 = Address(email_address="address1") + u1 = User(name='jack', addresses=[a1]) + + u2 = User(name='ed') + sess.add_all([u1, u2]) + sess.commit() # everything is expired + + # load u1.addresses collection + u1.addresses + + u2.addresses.append(a1) + + # backref fires + assert a1.user is u2 + + # doesn't extend to the previous collection tho, + # which was already loaded. + # flushing at this point means its anyone's guess. + assert a1 in u1.addresses + assert a1 in u2.addresses + + @testing.resolve_artifact_names + def test_collection_move_notloaded(self): + sess = sessionmaker()() + a1 = Address(email_address="address1") + u1 = User(name='jack', addresses=[a1]) + + u2 = User(name='ed') + sess.add_all([u1, u2]) + sess.commit() # everything is expired + + u2.addresses.append(a1) + + # backref fires + assert a1.user is u2 + + # u1.addresses wasn't loaded, + # so when it loads its correct + assert a1 not in u1.addresses + assert a1 in u2.addresses + + @testing.resolve_artifact_names + def test_collection_move_commitfirst(self): + sess = sessionmaker()() + a1 = Address(email_address="address1") + u1 = User(name='jack', addresses=[a1]) + + u2 = User(name='ed') + sess.add_all([u1, u2]) + sess.commit() # everything is expired + + # load u1.addresses collection + u1.addresses + + u2.addresses.append(a1) + + # backref fires + assert a1.user is u2 + + # everything expires, no changes in + # u1.addresses, so all is fine + sess.commit() + assert a1 not in u1.addresses + assert a1 in u2.addresses + + @testing.resolve_artifact_names + def test_scalar_move_preloaded(self): + sess = sessionmaker()() + + u1 = User(name='jack') + u2 = User(name='ed') + a1 = Address(email_address='a1') + a1.user = u1 + sess.add_all([u1, u2, a1]) + sess.commit() + + # u1.addresses is loaded + u1.addresses + + # direct set - the fetching of the + # "old" u1 here allows the backref + # to remove it from the addresses collection + a1.user = u2 + + assert a1 not in u1.addresses + assert a1 in u2.addresses + + + @testing.resolve_artifact_names + def test_scalar_move_notloaded(self): + sess = sessionmaker()() + + u1 = User(name='jack') + u2 = User(name='ed') + a1 = Address(email_address='a1') + a1.user = u1 + sess.add_all([u1, u2, a1]) + sess.commit() + + # direct set - the fetching of the + # "old" u1 here allows the backref + # to remove it from the addresses collection + a1.user = u2 + + assert a1 not in u1.addresses + assert a1 in u2.addresses + + @testing.resolve_artifact_names + def test_scalar_move_commitfirst(self): + sess = sessionmaker()() + + u1 = User(name='jack') + u2 = User(name='ed') + a1 = Address(email_address='a1') + a1.user = u1 + sess.add_all([u1, u2, a1]) + sess.commit() + + # u1.addresses is loaded + u1.addresses + + # direct set - the fetching of the + # "old" u1 here allows the backref + # to remove it from the addresses collection + a1.user = u2 + + sess.commit() + assert a1 not in u1.addresses + assert a1 in u2.addresses + +class O2OScalarBackrefMoveTest(_fixtures.FixtureTest): + run_inserts = None + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(Address, addresses) + mapper(User, users, properties = { + 'address':relation(Address, backref=backref("user"), uselist=False) + }) + + @testing.resolve_artifact_names + def test_collection_move_preloaded(self): + sess = sessionmaker()() + a1 = Address(email_address="address1") + u1 = User(name='jack', address=a1) + + u2 = User(name='ed') + sess.add_all([u1, u2]) + sess.commit() # everything is expired + + # load u1.address + u1.address + + # reassign + u2.address = a1 + assert u2.address is a1 + + # backref fires + assert a1.user is u2 + + # doesn't extend to the previous attribute tho. + # flushing at this point means its anyone's guess. + assert u1.address is a1 + assert u2.address is a1 + + @testing.resolve_artifact_names + def test_scalar_move_preloaded(self): + sess = sessionmaker()() + a1 = Address(email_address="address1") + a2 = Address(email_address="address1") + u1 = User(name='jack', address=a1) + + sess.add_all([u1, a1, a2]) + sess.commit() # everything is expired + + # load a1.user + a1.user + + # reassign + a2.user = u1 + + # backref fires + assert u1.address is a2 + + # stays on both sides + assert a1.user is u1 + assert a2.user is u1 + + @testing.resolve_artifact_names + def test_collection_move_notloaded(self): + sess = sessionmaker()() + a1 = Address(email_address="address1") + u1 = User(name='jack', address=a1) + + u2 = User(name='ed') + sess.add_all([u1, u2]) + sess.commit() # everything is expired + + # reassign + u2.address = a1 + assert u2.address is a1 + + # backref fires + assert a1.user is u2 + + # u1.address loads now after a flush + assert u1.address is None + assert u2.address is a1 + + @testing.resolve_artifact_names + def test_scalar_move_notloaded(self): + sess = sessionmaker()() + a1 = Address(email_address="address1") + a2 = Address(email_address="address1") + u1 = User(name='jack', address=a1) + + sess.add_all([u1, a1, a2]) + sess.commit() # everything is expired + + # reassign + a2.user = u1 + + # backref fires + assert u1.address is a2 + + # stays on both sides + assert a1.user is u1 + assert a2.user is u1 + + @testing.resolve_artifact_names + def test_collection_move_commitfirst(self): + sess = sessionmaker()() + a1 = Address(email_address="address1") + u1 = User(name='jack', address=a1) + + u2 = User(name='ed') + sess.add_all([u1, u2]) + sess.commit() # everything is expired + + # load u1.address + u1.address + + # reassign + u2.address = a1 + assert u2.address is a1 + + # backref fires + assert a1.user is u2 + + # the commit cancels out u1.addresses + # being loaded, on next access its fine. + sess.commit() + assert u1.address is None + assert u2.address is a1 + + @testing.resolve_artifact_names + def test_scalar_move_commitfirst(self): + sess = sessionmaker()() + a1 = Address(email_address="address1") + a2 = Address(email_address="address2") + u1 = User(name='jack', address=a1) + + sess.add_all([u1, a1, a2]) + sess.commit() # everything is expired + + # load + assert a1.user is u1 + + # reassign + a2.user = u1 + + # backref fires + assert u1.address is a2 + + # didnt work this way tho + assert a1.user is u1 + + # moves appropriately after commit + sess.commit() + assert u1.address is a2 + assert a1.user is None + assert a2.user is u1 + +class O2OScalarMoveTest(_fixtures.FixtureTest): + run_inserts = None + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(Address, addresses) + mapper(User, users, properties = { + 'address':relation(Address, uselist=False) + }) + + @testing.resolve_artifact_names + def test_collection_move_commitfirst(self): + sess = sessionmaker()() + a1 = Address(email_address="address1") + u1 = User(name='jack', address=a1) + + u2 = User(name='ed') + sess.add_all([u1, u2]) + sess.commit() # everything is expired + + # load u1.address + u1.address + + # reassign + u2.address = a1 + assert u2.address is a1 + + # the commit cancels out u1.addresses + # being loaded, on next access its fine. + sess.commit() + assert u1.address is None + assert u2.address is a1 + +class O2OScalarOrphanTest(_fixtures.FixtureTest): + run_inserts = None + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(Address, addresses) + mapper(User, users, properties = { + 'address':relation(Address, uselist=False, + backref=backref('user', single_parent=True, cascade="all, delete-orphan")) + }) + + @testing.resolve_artifact_names + def test_m2o_event(self): + sess = sessionmaker()() + a1 = Address(email_address="address1") + u1 = User(name='jack', address=a1) + + sess.add(u1) + sess.commit() + sess.expunge(u1) + + u2= User(name='ed') + # the _SingleParent extension sets the backref get to "active" ! + # u1 gets loaded and deleted + u2.address = a1 + sess.commit() + assert sess.query(User).count() == 1 + + +class M2MScalarMoveTest(_fixtures.FixtureTest): + run_inserts = None + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + mapper(Item, items, properties={ + 'keyword':relation(Keyword, secondary=item_keywords, uselist=False, backref=backref("item", uselist=False)) + }) + mapper(Keyword, keywords) + + @testing.resolve_artifact_names + def test_collection_move_preloaded(self): + sess = sessionmaker()() + + k1 = Keyword(name='k1') + i1 = Item(description='i1', keyword=k1) + i2 = Item(description='i2') + + sess.add_all([i1, i2, k1]) + sess.commit() # everything is expired + + # load i1.keyword + assert i1.keyword is k1 + + i2.keyword = k1 + + assert k1.item is i2 + + # nothing happens. + assert i1.keyword is k1 + assert i2.keyword is k1 + + @testing.resolve_artifact_names + def test_collection_move_notloaded(self): + sess = sessionmaker()() + + k1 = Keyword(name='k1') + i1 = Item(description='i1', keyword=k1) + i2 = Item(description='i2') + + sess.add_all([i1, i2, k1]) + sess.commit() # everything is expired + + i2.keyword = k1 + + assert k1.item is i2 + + assert i1.keyword is None + assert i2.keyword is k1 + + @testing.resolve_artifact_names + def test_collection_move_commit(self): + sess = sessionmaker()() + + k1 = Keyword(name='k1') + i1 = Item(description='i1', keyword=k1) + i2 = Item(description='i2') + + sess.add_all([i1, i2, k1]) + sess.commit() # everything is expired + + # load i1.keyword + assert i1.keyword is k1 + + i2.keyword = k1 + + assert k1.item is i2 + + sess.commit() + assert i1.keyword is None + assert i2.keyword is k1 diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 4ce36f1ff5..cb60bef69a 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -225,6 +225,11 @@ class InvalidGenerationsTest(QueryTest): assert_raises(sa_exc.InvalidRequestError, q.having, 'foo') + q.enable_assertions(False).join("addresses") + q.enable_assertions(False).filter(User.name=='ed') + q.enable_assertions(False).order_by('foo') + q.enable_assertions(False).group_by('foo') + def test_no_from(self): s = create_session() @@ -236,6 +241,10 @@ class InvalidGenerationsTest(QueryTest): q = s.query(User).order_by(User.id) assert_raises(sa_exc.InvalidRequestError, q.select_from, users) + + assert_raises(sa_exc.InvalidRequestError, q.select_from, users) + + q.enable_assertions(False).select_from(users) # this is fine, however q.from_self() diff --git a/test/sql/test_query.py b/test/sql/test_query.py index 9d3d785f95..bbc399aa6b 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -481,6 +481,11 @@ class QueryTest(TestBase): self.assert_(r['query_users.user_id']) == 1 self.assert_(r['query_users.user_name']) == "john" + # unary experssions + r = select([users.c.user_name.distinct()]).order_by(users.c.user_name).execute().first() + eq_(r[users.c.user_name], 'jack') + eq_(r.user_name, 'jack') + def test_result_case_sensitivity(self): """test name normalization for result sets.""" @@ -493,6 +498,7 @@ class QueryTest(TestBase): assert row.keys() == ["case_insensitive", "CaseSensitive"] + def test_row_as_args(self): users.insert().execute(user_id=1, user_name='john') r = users.select(users.c.user_id==1).execute().first() diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 841dda53fe..9c90549e29 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -496,6 +496,16 @@ class ExpressionTest(TestBase, AssertsExecutionResults): # this one relies upon anonymous labeling to assemble result # processing rules on the column. assert testing.db.execute(select([expr])).scalar() == -15 + + def test_distinct(self): + s = select([distinct(test_table.c.avalue)]) + eq_(testing.db.execute(s).scalar(), 25) + + s = select([test_table.c.avalue.distinct()]) + eq_(testing.db.execute(s).scalar(), 25) + + assert distinct(test_table.c.data).type == test_table.c.data.type + assert test_table.c.data.distinct().type == test_table.c.data.type class DateTest(TestBase, AssertsExecutionResults): @classmethod