merged -r6172:6204 of trunk
author	Mike Bayer <mike_mp@zzzcomputing.com>
	Sun, 26 Jul 2009 18:58:54 +0000 (18:58 +0000)
committer	Mike Bayer <mike_mp@zzzcomputing.com>
	Sun, 26 Jul 2009 18:58:54 +0000 (18:58 +0000)
25 files changed:
CHANGES
doc/build/ormtutorial.rst
lib/sqlalchemy/ext/associationproxy.py
lib/sqlalchemy/ext/declarative.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/dynamic.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/unitofwork.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/util.py
test/dialect/test_mssql.py
test/engine/test_metadata.py
test/ext/test_associationproxy.py
test/ext/test_declarative.py
test/orm/inheritance/test_basic.py
test/orm/test_backref_mutations.py [new file with mode: 0644]
test/orm/test_query.py
test/sql/test_query.py
test/sql/test_types.py

diff --git a/CHANGES b/CHANGES
index 78f426aaa46c1adcc750625b0f3204dff51a504d..0a4c34b3d596b92f06299a4928c8b813a55c96b2 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -16,6 +16,35 @@ CHANGES
       during a flush.  This is currently to support 
       many-to-many relations from concrete inheritance setups.
       Outside of that use case, YMMV.  [ticket:1477]
+    
+    - Squeezed a few more unnecessary "lazy loads" out of 
+      relation().  When a collection is mutated, many-to-one
+      backrefs on the other side will not fire off to load
+      the "old" value, unless "single_parent=True" is set.  
+      A direct assignment of a many-to-one still loads 
+      the "old" value in order to update backref collections 
+      on that value, which may be present in the session 
+      already, thus maintaining the 0.5 behavioral contract.
+      [ticket:1483]  (a short sketch follows this file's diff)
+      
+    - Fixed bug whereby a load/refresh of joined table
+      inheritance attributes which were based on 
+      column_property() or similar would fail to evaluate.
+      [ticket:1480]
+      
+    - Improved error message when query() is called with
+      a non-SQL/entity expression. [ticket:1476]
+    
+    - Using False or 0 as a polymorphic discriminator now
+      works on the base class as well as a subclass.
+      [ticket:1440]
+
+    - Added enable_assertions(False) to Query which disables
+      the usual assertions for expected state - used
+      by Query subclasses to engineer custom state.
+      [ticket:1424].  See
+      http://www.sqlalchemy.org/trac/wiki/UsageRecipes/PreFilteredQuery
+      for an example.
       
 - sql
     - Fixed a bug in extract() introduced in 0.5.4 whereby
@@ -23,6 +52,24 @@ CHANGES
       ClauseElement, causing various errors within more 
       complex SQL transformations.
       
+    - Unary expressions such as DISTINCT propagate their 
+      type handling to result sets, allowing conversions like
+      unicode and such to take place.  [ticket:1420]
+    
+    - Fixed bug in Table and Column whereby passing empty
+      dict for "info" argument would raise an exception.
+      [ticket:1482]
+      
+- ext
+   - The collection proxies produced by associationproxy are now
+     pickleable.  A user-defined proxy_factory however
+     is still not pickleable unless it defines __getstate__
+     and __setstate__. [ticket:1446]
+
+   - Declarative will raise an informative exception if 
+     __table_args__ is passed as a tuple with no dict argument.
+     Improved documentation.  [ticket:1468]
+     
 0.5.5
 =======
 - general
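
To illustrate the first relation() entry above ([ticket:1483]): collection
mutations no longer force a lazy load of the many-to-one side's previous
value, while direct scalar assignment still does.  A minimal sketch,
assuming a typical bidirectional User/Address mapping (the classes and
attribute names are illustrative, not part of this commit)::

    # hypothetical mapping with a many-to-one backref
    mapper(User, users_table, properties={
        'addresses': relation(Address, backref='user')
    })

    a1 = jack.addresses[0]
    ed.addresses.append(a1)   # backref sets a1.user; the "old" value of
                              # a1.user is no longer lazily loaded here

    a1.user = ed              # direct assignment still loads the "old"
                              # value, so the previous User's "addresses"
                              # collection stays up to date (the 0.5
                              # behavioral contract)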
index acdbba149eb6bbd24da10f4158156142710dcc36..c10d457f14949ae176ac4a9d1516ff6afbb6f374 100644 (file)
@@ -567,6 +567,54 @@ To use an entirely string-based statement, using ``from_statement()``; just ensu
     ['ed']
     {stop}[<User('ed','Ed Jones', 'f8s7ccs')>]
 
+Counting
+--------
+
+``Query`` includes a convenience method for counting called ``count()``:
+
+.. sourcecode:: python+sql
+
+    {sql}>>> session.query(User).filter(User.name.like('%ed')).count()
+    SELECT count(1) AS count_1 
+    FROM users 
+    WHERE users.name LIKE ?
+    ['%ed']
+    {stop}2
+    
+The ``count()`` method is used to determine how many rows the SQL statement would return, and is mainly intended to return a simple count of a single type of entity, in this case ``User``.   For more complicated sets of columns or entities where the "thing to be counted" needs to be indicated more specifically, ``count()`` is probably not what you want.  Below, a query for individual columns does return the expected result:
+
+.. sourcecode:: python+sql
+
+    {sql}>>> session.query(User.id, User.name).filter(User.name.like('%ed')).count()
+    SELECT count(1) AS count_1 
+    FROM (SELECT users.id AS users_id, users.name AS users_name 
+    FROM users 
+    WHERE users.name LIKE ?) AS anon_1
+    ['%ed']
+    {stop}2
+
+...but if you look at the generated SQL, you can see that SQLAlchemy noticed we were selecting individual column expressions, and wrapped them in a subquery so as to be sure the result is still the "number of rows".   This defensive behavior is not really needed here, and in other cases it is not what we want at all, such as when we want a grouping of counts per name:
+
+.. sourcecode:: python+sql
+
+    {sql}>>> session.query(User.name).group_by(User.name).count()
+    SELECT count(1) AS count_1 
+    FROM (SELECT users.name AS users_name 
+    FROM users GROUP BY users.name) AS anon_1
+    []
+    {stop}4
+
+We don't want the number ``4`` - we wanted some rows back.   So for detailed queries where you need to count something specific, use the ``func.count()`` function as a column expression:
+
+.. sourcecode:: python+sql
+    
+    >>> from sqlalchemy import func
+    {sql}>>> session.query(func.count(User.name), User.name).group_by(User.name).all()
+    SELECT count(users.name) AS count_1, users.name AS users_name 
+    FROM users GROUP BY users.name
+    {stop}[]
+    [(1, u'ed'), (1, u'fred'), (1, u'mary'), (1, u'wendy')]
+
 Building a Relation
 ====================
 
@@ -824,7 +872,7 @@ The ``Query`` is suitable for generating statements which can be used as subquer
         (SELECT user_id, count(*) AS address_count FROM addresses GROUP BY user_id) AS adr_count
         ON users.id=adr_count.user_id
 
-Using the ``Query``, we build a statement like this from the inside out.  The ``statement`` accessor returns a SQL expression representing the statement generated by a particular ``Query`` - this is an instance of a ``select()`` construct, which are described in :ref:`sql`::
+Using the ``Query``, we build a statement like this from the inside out.  The ``statement`` accessor returns a SQL expression representing the statement generated by a particular ``Query`` - this is an instance of a ``select()`` construct, which are described in :ref:`sqlexpression_toplevel`::
 
     >>> from sqlalchemy.sql import func
     >>> stmt = session.query(Address.user_id, func.count('*').label('address_count')).group_by(Address.user_id).subquery()
index 315142d8e0427119f052cc27b8f01187e6a01ed3..e126fe638d772bfad177a12b303caae9e6876ca6 100644 (file)
@@ -140,26 +140,14 @@ class AssociationProxy(object):
         return (orm.class_mapper(self.owning_class).
                 get_property(self.target_collection))
 
+    @property
     def target_class(self):
         """The class the proxy is attached to."""
         return self._get_property().mapper.class_
-    target_class = property(target_class)
 
     def _target_is_scalar(self):
         return not self._get_property().uselist
 
-    def _lazy_collection(self, weakobjref):
-        target = self.target_collection
-        del self
-        def lazy_collection():
-            obj = weakobjref()
-            if obj is None:
-                raise exceptions.InvalidRequestError(
-                    "stale association proxy, parent object has gone out of "
-                    "scope")
-            return getattr(obj, target)
-        return lazy_collection
-
     def __get__(self, obj, class_):
         if self.owning_class is None:
             self.owning_class = class_ and class_ or type(obj)
@@ -181,10 +169,10 @@ class AssociationProxy(object):
                     return proxy
             except AttributeError:
                 pass
-            proxy = self._new(self._lazy_collection(weakref.ref(obj)))
+            proxy = self._new(_lazy_collection(obj, self.target_collection))
             setattr(obj, self.key, (id(obj), proxy))
             return proxy
-
+    
     def __set__(self, obj, values):
         if self.owning_class is None:
             self.owning_class = type(obj)
@@ -238,13 +226,13 @@ class AssociationProxy(object):
             getter, setter = self.getset_factory(self.collection_class, self)
         else:
             getter, setter = self._default_getset(self.collection_class)
-
+        
         if self.collection_class is list:
-            return _AssociationList(lazy_collection, creator, getter, setter)
+            return _AssociationList(lazy_collection, creator, getter, setter, self)
         elif self.collection_class is dict:
-            return _AssociationDict(lazy_collection, creator, getter, setter)
+            return _AssociationDict(lazy_collection, creator, getter, setter, self)
         elif self.collection_class is set:
-            return _AssociationSet(lazy_collection, creator, getter, setter)
+            return _AssociationSet(lazy_collection, creator, getter, setter, self)
         else:
             raise exceptions.ArgumentError(
                 'could not guess which interface to use for '
@@ -252,6 +240,18 @@ class AssociationProxy(object):
                 'proxy_factory and proxy_bulk_set manually' %
                 (self.collection_class.__name__, self.target_collection))
 
+    def _inflate(self, proxy):
+        creator = self.creator and self.creator or self.target_class
+
+        if self.getset_factory:
+            getter, setter = self.getset_factory(self.collection_class, self)
+        else:
+            getter, setter = self._default_getset(self.collection_class)
+        
+        proxy.creator = creator
+        proxy.getter = getter
+        proxy.setter = setter
+
     def _set(self, proxy, values):
         if self.proxy_bulk_set:
             self.proxy_bulk_set(proxy, values)
@@ -266,12 +266,32 @@ class AssociationProxy(object):
                'no proxy_bulk_set supplied for custom '
                'collection_class implementation')
 
+class _lazy_collection(object):
+    def __init__(self, obj, target):
+        self.ref = weakref.ref(obj)
+        self.target = target
 
-class _AssociationList(object):
-    """Generic, converting, list-to-list proxy."""
-
-    def __init__(self, lazy_collection, creator, getter, setter):
-        """Constructs an _AssociationList.
+    def __call__(self):
+        obj = self.ref()
+        if obj is None:
+            raise exceptions.InvalidRequestError(
+               "stale association proxy, parent object has gone out of "
+               "scope")
+        return getattr(obj, self.target)
+
+    def __getstate__(self):
+        return {'obj':self.ref(), 'target':self.target}
+    
+    def __setstate__(self, state):
+        self.ref = weakref.ref(state['obj'])
+        self.target = state['target']
+
+class _AssociationCollection(object):
+    def __init__(self, lazy_collection, creator, getter, setter, parent):
+        """Constructs an _AssociationCollection.  
+        
+        This will always be a subclass of either _AssociationList,
+        _AssociationSet, or _AssociationDict.
 
         lazy_collection
           A callable returning a list-based collection of entities (usually an
@@ -296,9 +316,27 @@ class _AssociationList(object):
         self.creator = creator
         self.getter = getter
         self.setter = setter
+        self.parent = parent
 
     col = property(lambda self: self.lazy_collection())
 
+    def __len__(self):
+        return len(self.col)
+
+    def __nonzero__(self):
+        return bool(self.col)
+
+    def __getstate__(self):
+        return {'parent':self.parent, 'lazy_collection':self.lazy_collection}
+
+    def __setstate__(self, state):
+        self.parent = state['parent']
+        self.lazy_collection = state['lazy_collection']
+        self.parent._inflate(self)
+    
+class _AssociationList(_AssociationCollection):
+    """Generic, converting, list-to-list proxy."""
+
     def _create(self, value):
         return self.creator(value)
 
@@ -308,15 +346,6 @@ class _AssociationList(object):
     def _set(self, object, value):
         return self.setter(object, value)
 
-    def __len__(self):
-        return len(self.col)
-
-    def __nonzero__(self):
-        if self.col:
-            return True
-        else:
-            return False
-
     def __getitem__(self, index):
         return self._get(self.col[index])
 
@@ -494,39 +523,9 @@ class _AssociationList(object):
 
 
 _NotProvided = util.symbol('_NotProvided')
-class _AssociationDict(object):
+class _AssociationDict(_AssociationCollection):
     """Generic, converting, dict-to-dict proxy."""
 
-    def __init__(self, lazy_collection, creator, getter, setter):
-        """Constructs an _AssociationDict.
-
-        lazy_collection
-          A callable returning a dict-based collection of entities (usually an
-          object attribute managed by a SQLAlchemy relation())
-
-        creator
-          A function that creates new target entities.  Given two parameters:
-          key and value.  The assertion is assumed::
-
-            obj = creator(somekey, somevalue)
-            assert getter(somekey) == somevalue
-
-        getter
-          A function.  Given an associated object and a key, return the
-          'value'.
-
-        setter
-          A function.  Given an associated object, a key and a value, store
-          that value on the object.
-
-        """
-        self.lazy_collection = lazy_collection
-        self.creator = creator
-        self.getter = getter
-        self.setter = setter
-
-    col = property(lambda self: self.lazy_collection())
-
     def _create(self, key, value):
         return self.creator(key, value)
 
@@ -536,15 +535,6 @@ class _AssociationDict(object):
     def _set(self, object, key, value):
         return self.setter(object, key, value)
 
-    def __len__(self):
-        return len(self.col)
-
-    def __nonzero__(self):
-        if self.col:
-            return True
-        else:
-            return False
-
     def __getitem__(self, key):
         return self._get(self.col[key])
 
@@ -669,38 +659,9 @@ class _AssociationDict(object):
     del func_name, func
 
 
-class _AssociationSet(object):
+class _AssociationSet(_AssociationCollection):
     """Generic, converting, set-to-set proxy."""
 
-    def __init__(self, lazy_collection, creator, getter, setter):
-        """Constructs an _AssociationSet.
-
-        collection
-          A callable returning a set-based collection of entities (usually an
-          object attribute managed by a SQLAlchemy relation())
-
-        creator
-          A function that creates new target entities.  Given one parameter:
-          value.  The assertion is assumed::
-
-            obj = creator(somevalue)
-            assert getter(obj) == somevalue
-
-        getter
-          A function.  Given an associated object, return the 'value'.
-
-        setter
-          A function.  Given an associated object and a value, store that
-          value on the object.
-
-        """
-        self.lazy_collection = lazy_collection
-        self.creator = creator
-        self.getter = getter
-        self.setter = setter
-
-    col = property(lambda self: self.lazy_collection())
-
     def _create(self, value):
         return self.creator(value)
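
With _AssociationCollection and _lazy_collection now implementing
__getstate__ and __setstate__, the stock proxy collections survive
pickling: __getstate__ dereferences the weakref and stores the real parent
object, while __setstate__ rebuilds the weakref and has the owning
AssociationProxy _inflate() the creator/getter/setter functions.  A minimal
sketch, assuming a Parent class configured like the one in the tests at the
bottom of this page (the mapping itself is assumed)::

    import pickle

    # assumes Parent.kids = association_proxy('children', 'name') and a
    # fully configured, picklable mapping, as in test_associationproxy.py
    p = session.query(Parent).first()
    kids_copy = pickle.loads(pickle.dumps(p.kids))
    assert list(kids_copy) == list(p.kids)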
 
index 43369311b325ad12f94d71457de9d8c9acb9dbe7..c37211ac3dbb6f430d9f8f4523516dff130e4bb2 100644 (file)
@@ -214,29 +214,39 @@ ORM function::
 Table Configuration
 ===================
 
-As an alternative to ``__tablename__``, a direct :class:`~sqlalchemy.schema.Table` construct may be
-used.  The :class:`~sqlalchemy.schema.Column` objects, which in this case require their names, will be
-added to the mapping just like a regular mapping to a table::
+Table arguments other than the name, metadata, and mapped Column arguments 
+are specified using the ``__table_args__`` class attribute.   This attribute
+accommodates both positional as well as keyword arguments that are normally 
+sent to the :class:`~sqlalchemy.schema.Table` constructor.   The attribute can be specified
+in one of two forms.  One is as a dictionary::
 
     class MyClass(Base):
-        __table__ = Table('my_table', Base.metadata,
-            Column('id', Integer, primary_key=True),
-            Column('name', String(50))
-        )
+        __tablename__ = 'sometable'
+        __table_args__ = {'mysql_engine':'InnoDB'}
 
-Other table-based attributes include ``__table_args__``, which is
-either a dictionary as in::
+The other is a tuple of the form ``(arg1, arg2, ..., {kwarg1:value, ...})``, which 
+allows positional arguments to be specified as well (usually constraints)::
 
     class MyClass(Base):
         __tablename__ = 'sometable'
-        __table_args__ = {'mysql_engine':'InnoDB'}
-        
-or a dictionary-containing tuple in the form 
-``(arg1, arg2, ..., {kwarg1:value, ...})``, as in::
+        __table_args__ = (
+                ForeignKeyConstraint(['id'], ['remote_table.id']),
+                UniqueConstraint('foo'),
+                {'autoload':True}
+                )
+
+Note that the dictionary is required in the tuple form even if empty.
+
+As an alternative to ``__tablename__``, a direct :class:`~sqlalchemy.schema.Table` 
+construct may be used.  The :class:`~sqlalchemy.schema.Column` objects, which 
+in this case require their names, will be
+added to the mapping just like a regular mapping to a table::
 
     class MyClass(Base):
-        __tablename__ = 'sometable'
-        __table_args__ = (ForeignKeyConstraint(['id'], ['remote_table.id']), {'autoload':True})
+        __table__ = Table('my_table', Base.metadata,
+            Column('id', Integer, primary_key=True),
+            Column('name', String(50))
+        )
 
 Mapper Configuration
 ====================
@@ -468,6 +478,11 @@ def _as_declarative(cls, classname, dict_):
             elif isinstance(table_args, tuple):
                 args = table_args[0:-1]
                 table_kw = table_args[-1]
+                if len(table_args) < 2 or not isinstance(table_kw, dict):
+                    raise exceptions.ArgumentError(
+                                "Tuple form of __table_args__ is "
+                                "(arg1, arg2, arg3, ..., {'kw1':val1, 'kw2':val2, ...})"
+                            )
             else:
                 args, table_kw = (), {}
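
The new check turns a silently misconfigured ``__table_args__`` into an
immediate ArgumentError at class creation time.  A quick sketch of the
rejected versus accepted forms (class, table, and constraint names are
illustrative)::

    # rejected as of this commit: tuple form with no trailing dictionary;
    # this class never gets created
    class BadWidget(Base):
        __tablename__ = 'widgets_bad'
        __table_args__ = (UniqueConstraint('name'),)   # raises ArgumentError
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    # accepted: the dictionary is required, even if empty
    class Widget(Base):
        __tablename__ = 'widgets'
        __table_args__ = (UniqueConstraint('name'), {})
        id = Column(Integer, primary_key=True)
        name = Column(String(50))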
 
index 09a05b56fe705ac4decea0c1129cc26a800a9ec2..f6947dbc1156abf2e723ad0a422e7f724b36d61f 100644 (file)
@@ -262,7 +262,6 @@ class AttributeImpl(object):
         active_history
           indicates that get_history() should always return the "old" value,
           even if it means executing a lazy callable upon attribute change.
-          This flag is set to True if any extensions are present.
 
         parent_token
           Usually references the MapperProperty, used as a key for
@@ -286,6 +285,10 @@ class AttributeImpl(object):
         else:
             self.is_equal = compare_function
         self.extensions = util.to_list(extension or [])
+        for e in self.extensions:
+            if e.active_history:
+                active_history = True
+                break
         self.active_history = active_history
         self.expire_missing = expire_missing
         
@@ -383,12 +386,12 @@ class AttributeImpl(object):
             return self.initialize(state, dict_)
 
     def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
-        self.set(state, dict_, value, initiator)
+        self.set(state, dict_, value, initiator, passive=passive)
 
     def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
-        self.set(state, dict_, None, initiator)
+        self.set(state, dict_, None, initiator, passive=passive)
 
-    def set(self, state, dict_, value, initiator):
+    def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
         raise NotImplementedError()
 
     def get_committed_value(self, state, dict_, passive=PASSIVE_OFF):
@@ -421,7 +424,7 @@ class ScalarAttributeImpl(AttributeImpl):
     def delete(self, state, dict_):
 
         # TODO: catch key errors, convert to attributeerror?
-        if self.active_history or self.extensions:
+        if self.active_history:
             old = self.get(state, dict_)
         else:
             old = dict_.get(self.key, NO_VALUE)
@@ -436,11 +439,11 @@ class ScalarAttributeImpl(AttributeImpl):
         return History.from_attribute(
             self, state, dict_.get(self.key, NO_VALUE))
 
-    def set(self, state, dict_, value, initiator):
+    def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
         if initiator is self:
             return
 
-        if self.active_history or self.extensions:
+        if self.active_history:
             old = self.get(state, dict_)
         else:
             old = dict_.get(self.key, NO_VALUE)
@@ -511,7 +514,7 @@ class MutableScalarAttributeImpl(ScalarAttributeImpl):
         ScalarAttributeImpl.delete(self, state, dict_)
         state.mutable_dict.pop(self.key)
 
-    def set(self, state, dict_, value, initiator):
+    def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
         if initiator is self:
             return
 
@@ -559,7 +562,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
             else:
                 return History.from_attribute(self, state, current)
 
-    def set(self, state, dict_, value, initiator):
+    def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
         """Set a value on the given InstanceState.
 
         `initiator` is the ``InstrumentedAttribute`` that initiated the
@@ -569,9 +572,21 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
         """
         if initiator is self:
             return
-
-        # may want to add options to allow the get() here to be passive
-        old = self.get(state, dict_)
+        
+        if self.active_history:
+            old = self.get(state, dict_)
+        else:
+            # this would be the "laziest" approach,
+            # however it breaks currently expected backref
+            # behavior
+            #old = dict_.get(self.key, None)
+            # instead, use the "passive" setting, which
+            # is only going to be PASSIVE_NOCALLABLES if it
+            # came from a backref
+            old = self.get(state, dict_, passive=passive)
+            if old is PASSIVE_NORESULT:
+                old = None
+             
         value = self.fire_replace_event(state, dict_, value, old, initiator)
         dict_[self.key] = value
 
@@ -707,7 +722,7 @@ class CollectionAttributeImpl(AttributeImpl):
         else:
             collection.remove_with_event(value, initiator)
 
-    def set(self, state, dict_, value, initiator):
+    def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
         """Set a value on the given object.
 
         `initiator` is the ``InstrumentedAttribute`` that initiated the
@@ -808,6 +823,9 @@ class GenericBackrefExtension(interfaces.AttributeExtension):
     are two objects which contain scalar references to each other.
 
     """
+    
+    active_history = False
+    
     def __init__(self, key):
         self.key = key
 
index 70243291dc3e279bdae383bbe02b7ae1b157f213..0bc7bab24ee0d0dc137450288adfc79746df0c73 100644 (file)
@@ -100,7 +100,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         dict_[self.key] = True
         return state.committed_state[self.key]
 
-    def set(self, state, dict_, value, initiator):
+    def set(self, state, dict_, value, initiator, passive=attributes.PASSIVE_OFF):
         if initiator is self:
             return
 
index c590f4323fb1dd441663966ad8b9de10f6770328..eaafe5761a43d8baed85e86f2c433ca583e36fb0 100644 (file)
@@ -783,6 +783,11 @@ class AttributeExtension(object):
 
     """
 
+    active_history = True
+    """indicates that the set() method would like to receive the 'old' value,
+    even if it means firing lazy callables.
+    """
+    
     def append(self, state, value, initiator):
         """Receive a collection append event.
 
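active_history is now a class-level flag on AttributeExtension, defaulting
to True (the prior behavior), so an extension that does not need the
previous value can opt out and avoid triggering lazy loads.  A hedged
sketch of a custom extension (the class and its use are illustrative, not
part of this commit)::

    from sqlalchemy.orm.interfaces import AttributeExtension

    class AuditNewValue(AttributeExtension):
        # opt out of old-value loading; set() may now receive an
        # unloaded or stale "old" value instead of firing a lazy callable
        active_history = False

        def set(self, state, value, oldvalue, initiator):
            # only the incoming value matters to this extension
            return value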
index d155f66d12237b1daa5957d2bbdd1c531d2f9a0a..bb241b6aeba26d50d1f008731f424ef54e250cad 100644 (file)
@@ -244,7 +244,7 @@ class Mapper(object):
             else:
                 self.mapped_table = self.local_table
 
-            if self.polymorphic_identity and not self.concrete:
+            if self.polymorphic_identity is not None and not self.concrete:
                 self._identity_class = self.inherits._identity_class
             else:
                 self._identity_class = self.class_
@@ -278,7 +278,7 @@ class Mapper(object):
             self._all_tables = set()
             self.base_mapper = self
             self.mapped_table = self.local_table
-            if self.polymorphic_identity:
+            if self.polymorphic_identity is not None:
                 self.polymorphic_map[self.polymorphic_identity] = self
             self._identity_class = self.class_
 
@@ -1101,7 +1101,12 @@ class Mapper(object):
         
         """
         props = self._props
-        tables = set(props[key].parent.local_table for key in attribute_names)
+        
+        tables = set(chain(*
+                        (sqlutil.find_tables(props[key].columns[0], check_columns=True) 
+                        for key in attribute_names)
+                    ))
+        
         if self.base_mapper.local_table in tables:
             return None
 
@@ -1138,7 +1143,8 @@ class Mapper(object):
             return None
 
         cond = sql.and_(*allconds)
-        return sql.select(tables, cond, use_labels=True)
+
+        return sql.select([props[key].columns[0] for key in attribute_names], cond, use_labels=True)
 
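The switch to ``is not None`` in the hunks above is what makes falsy
discriminator values usable.  A sketch of a mapping that previously
misbehaved (the classes and tables are illustrative)::

    # hypothetical inheritance setup using integer discriminators; before
    # this change, polymorphic_identity=0 on the base mapper failed the
    # truthiness test and never entered polymorphic_map
    mapper(Employee, employees, polymorphic_on=employees.c.type,
           polymorphic_identity=0)
    mapper(Manager, managers, inherits=Employee, polymorphic_identity=1)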
     def cascade_iterator(self, type_, state, halt_on=None):
         """Iterate each element and its mapper in an object graph,
index 0c4bbc6eb6a5390755c17ff0d73d97786b92388b..0ea67f6e53b7bfe6bb977e06001433786263f984 100644 (file)
@@ -54,40 +54,41 @@ def _generative(*assertions):
     return generate
 
 class Query(object):
-    """Encapsulates the object-fetching operations provided by Mappers."""
-
+    """ORM-level SQL construction object."""
+    
+    _enable_eagerloads = True
+    _enable_assertions = True
+    _with_labels = False
+    _criterion = None
+    _yield_per = None
+    _lockmode = None
+    _order_by = False
+    _group_by = False
+    _having = None
+    _distinct = False
+    _offset = None
+    _limit = None
+    _statement = None
+    _joinpoint = None
+    _correlate = frozenset()
+    _populate_existing = False
+    _version_check = False
+    _autoflush = True
+    _current_path = ()
+    _only_load_props = None
+    _refresh_state = None
+    _from_obj = ()
+    _filter_aliases = None
+    _from_obj_alias = None
+    _currenttables = frozenset()
+    
     def __init__(self, entities, session=None):
         self.session = session
 
         self._with_options = []
-        self._lockmode = None
-        self._order_by = False
-        self._group_by = False
-        self._distinct = False
-        self._offset = None
-        self._limit = None
-        self._statement = None
         self._params = {}
-        self._yield_per = None
-        self._criterion = None
-        self._correlate = set()
-        self._joinpoint = None
-        self._with_labels = False
-        self._enable_eagerloads = True
-        self.__joinable_tables = None
-        self._having = None
-        self._populate_existing = False
-        self._version_check = False
-        self._autoflush = True
         self._attributes = {}
-        self._current_path = ()
-        self._only_load_props = None
-        self._refresh_state = None
-        self._from_obj = ()
         self._polymorphic_adapters = {}
-        self._filter_aliases = None
-        self._from_obj_alias = None
-        self.__currenttables = set()
         self._set_entities(entities)
 
     def _set_entities(self, entities, entity_wrapper=None):
@@ -97,9 +98,9 @@ class Query(object):
         for ent in util.to_list(entities):
             entity_wrapper(self, ent)
 
-        self.__setup_aliasizers(self._entities)
+        self._setup_aliasizers(self._entities)
 
-    def __setup_aliasizers(self, entities):
+    def _setup_aliasizers(self, entities):
         if hasattr(self, '_mapper_adapter_map'):
             # usually safe to share a single map, but copying to prevent
             # subtle leaks if end-user is reusing base query with arbitrary
@@ -114,7 +115,8 @@ class Query(object):
                     mapper, selectable, is_aliased_class = _entity_info(entity)
                     if not is_aliased_class and mapper.with_polymorphic:
                         with_polymorphic = mapper._with_polymorphic_mappers
-                        self.__mapper_loads_polymorphically_with(mapper, sql_util.ColumnAdapter(selectable, mapper._equivalent_columns))
+                        self.__mapper_loads_polymorphically_with(mapper, 
+                                sql_util.ColumnAdapter(selectable, mapper._equivalent_columns))
                         adapter = None
                     elif is_aliased_class:
                         adapter = sql_util.ColumnAdapter(selectable, mapper._equivalent_columns)
@@ -131,7 +133,7 @@ class Query(object):
             for m in m2.iterate_to_root():
                 self._polymorphic_adapters[m.mapped_table] = self._polymorphic_adapters[m.local_table] = adapter
 
-    def __set_select_from(self, from_obj):
+    def _set_select_from(self, from_obj):
         if isinstance(from_obj, expression._SelectBaseMixin):
             from_obj = from_obj.alias()
 
@@ -142,7 +144,8 @@ class Query(object):
             self._from_obj_alias = sql_util.ColumnAdapter(from_obj, equivs)
 
     def _get_polymorphic_adapter(self, entity, selectable):
-        self.__mapper_loads_polymorphically_with(entity.mapper, sql_util.ColumnAdapter(selectable, entity.mapper._equivalent_columns))
+        self.__mapper_loads_polymorphically_with(entity.mapper, 
+                                sql_util.ColumnAdapter(selectable, entity.mapper._equivalent_columns))
 
     def _reset_polymorphic_adapter(self, mapper):
         for m2 in mapper._with_polymorphic_mappers:
@@ -151,7 +154,7 @@ class Query(object):
                 self._polymorphic_adapters.pop(m.mapped_table, None)
                 self._polymorphic_adapters.pop(m.local_table, None)
 
-    def __reset_joinpoint(self):
+    def _reset_joinpoint(self):
         self._joinpoint = None
         self._filter_aliases = None
 
@@ -210,9 +213,17 @@ class Query(object):
             return clause
 
         if getattr(self, '_disable_orm_filtering', not orm_only):
-            return visitors.replacement_traverse(clause, {'column_collections':False}, self.__replace_element(adapters))
+            return visitors.replacement_traverse(
+                                clause, 
+                                {'column_collections':False}, 
+                                self.__replace_element(adapters)
+                            )
         else:
-            return visitors.replacement_traverse(clause, {'column_collections':False}, self.__replace_orm_element(adapters))
+            return visitors.replacement_traverse(
+                                clause, 
+                                {'column_collections':False}, 
+                                self.__replace_orm_element(adapters)
+                            )
 
     def _entity_zero(self):
         return self._entities[0]
@@ -243,12 +254,16 @@ class Query(object):
 
     def _only_mapper_zero(self, rationale=None):
         if len(self._entities) > 1:
-            raise sa_exc.InvalidRequestError(rationale or "This operation requires a Query against a single mapper.")
+            raise sa_exc.InvalidRequestError(
+                    rationale or "This operation requires a Query against a single mapper."
+                )
         return self._mapper_zero()
 
     def _only_entity_zero(self, rationale=None):
         if len(self._entities) > 1:
-            raise sa_exc.InvalidRequestError(rationale or "This operation requires a Query against a single mapper.")
+            raise sa_exc.InvalidRequestError(
+                    rationale or "This operation requires a Query against a single mapper."
+                )
         return self._entity_zero()
 
     def _generate_mapper_zero(self):
@@ -264,7 +279,9 @@ class Query(object):
             equivs.update(ent.mapper._equivalent_columns)
         return equivs
 
-    def __no_criterion_condition(self, meth):
+    def _no_criterion_condition(self, meth):
+        if not self._enable_assertions:
+            return
         if self._criterion or self._statement or self._from_obj or \
                 self._limit is not None or self._offset is not None or \
                 self._group_by:
@@ -273,28 +290,36 @@ class Query(object):
         self._from_obj = ()
         self._statement = self._criterion = None
         self._order_by = self._group_by = self._distinct = False
-        self.__joined_tables = {}
 
-    def __no_clauseelement_condition(self, meth):
+    def _no_clauseelement_condition(self, meth):
+        if not self._enable_assertions:
+            return
         if self._order_by:
             raise sa_exc.InvalidRequestError("Query.%s() being called on a Query with existing criterion. " % meth)
-        self.__no_criterion_condition(meth)
+        self._no_criterion_condition(meth)
 
-    def __no_statement_condition(self, meth):
+    def _no_statement_condition(self, meth):
+        if not self._enable_assertions:
+            return
         if self._statement:
             raise sa_exc.InvalidRequestError(
                 ("Query.%s() being called on a Query with an existing full "
                  "statement - can't apply criterion.") % meth)
 
-    def __no_limit_offset(self, meth):
+    def _no_limit_offset(self, meth):
+        if not self._enable_assertions:
+            return
         if self._limit is not None or self._offset is not None:
-            # TODO: do we want from_self() to be implicit here ?  i vote explicit for the time being
-            raise sa_exc.InvalidRequestError("Query.%s() being called on a Query which already has LIMIT or OFFSET applied. "
-            "To modify the row-limited results of a Query, call from_self() first.  Otherwise, call %s() before limit() or offset() are applied." % (meth, meth)
+            raise sa_exc.InvalidRequestError(
+                "Query.%s() being called on a Query which already has LIMIT or OFFSET applied. "
+                "To modify the row-limited results of a Query, call from_self() first.  "
+                "Otherwise, call %s() before limit() or offset() are applied." % (meth, meth)
             )
 
-
-    def __get_options(self, populate_existing=None, version_check=None, only_load_props=None, refresh_state=None):
+    def _get_options(self, populate_existing=None, 
+                            version_check=None, 
+                            only_load_props=None, 
+                            refresh_state=None):
         if populate_existing:
             self._populate_existing = populate_existing
         if version_check:
@@ -315,7 +340,8 @@ class Query(object):
     def statement(self):
         """The full SELECT statement represented by this Query."""
 
-        return self._compile_context(labels=self._with_labels).statement._annotate({'_halt_adapt': True})
+        return self._compile_context(labels=self._with_labels).\
+                        statement._annotate({'_halt_adapt': True})
 
     def subquery(self):
         """return the full SELECT statement represented by this Query, embedded within an Alias.
@@ -358,7 +384,29 @@ class Query(object):
 
         """
         self._with_labels = True
-
+    
+    @_generative()
+    def enable_assertions(self, value):
+        """Control whether assertions are generated.
+        
+        When set to False, the returned Query will 
+        not assert its state before certain operations, 
+        including that LIMIT/OFFSET has not been applied
+        when filter() is called, no criterion exists
+        when get() is called, and no "from_statement()"
+        exists when filter()/order_by()/group_by() etc.
+        is called.  This more permissive mode is used by 
+        custom Query subclasses to specify criterion or 
+        other modifiers outside of the usual usage patterns.
+        
+        Care should be taken to ensure that the usage 
+        pattern is even possible.  A statement applied
+        by from_statement() will override any criterion
+        set by filter() or order_by(), for example.
+        
+        """
+        self._enable_assertions = value
+        
     @property
     def whereclause(self):
         """The WHERE criterion for this Query."""
@@ -375,7 +423,7 @@ class Query(object):
         """
         self._current_path = path
 
-    @_generative(__no_clauseelement_condition)
+    @_generative(_no_clauseelement_condition)
     def with_polymorphic(self, cls_or_mappers, selectable=None, discriminator=None):
         """Load columns for descendant mappers of this Query's mapper.
 
@@ -438,7 +486,9 @@ class Query(object):
         if hasattr(ident, '__composite_values__'):
             ident = ident.__composite_values__()
 
-        key = self._only_mapper_zero("get() can only be used against a single mapped class.").identity_key_from_primary_key(ident)
+        key = self._only_mapper_zero(
+                    "get() can only be used against a single mapped class."
+                ).identity_key_from_primary_key(ident)
         return self._get(key, ident)
 
     @classmethod
@@ -526,7 +576,11 @@ class Query(object):
                 if isinstance(prop, properties.PropertyLoader) and prop.mapper is self._mapper_zero():
                     break
             else:
-                raise sa_exc.InvalidRequestError("Could not locate a property which relates instances of class '%s' to instances of class '%s'" % (self._mapper_zero().class_.__name__, instance.__class__.__name__))
+                raise sa_exc.InvalidRequestError(
+                            "Could not locate a property which relates instances "
+                            "of class '%s' to instances of class '%s'" % 
+                            (self._mapper_zero().class_.__name__, instance.__class__.__name__)
+                        )
         else:
             prop = mapper.get_property(property, resolve_synonyms=True)
         return self.filter(prop.compare(operators.eq, instance, value_is_parent=True))
@@ -540,7 +594,7 @@ class Query(object):
 
         self._entities = list(self._entities)
         m = _MapperEntity(self, entity)
-        self.__setup_aliasizers([m])
+        self._setup_aliasizers([m])
 
     def from_self(self, *entities):
         """return a Query that selects from this Query's SELECT statement.
@@ -562,7 +616,7 @@ class Query(object):
         self._statement = self._criterion = None
         self._order_by = self._group_by = self._distinct = False
         self._limit = self._offset = None
-        self.__set_select_from(fromclause)
+        self._set_select_from(fromclause)
 
     def values(self, *columns):
         """Return an iterator yielding result tuples corresponding to the given list of columns"""
@@ -596,20 +650,20 @@ class Query(object):
         _ColumnEntity(self, column)
         # _ColumnEntity may add many entities if the
         # given arg is a FROM clause
-        self.__setup_aliasizers(self._entities[l:])
+        self._setup_aliasizers(self._entities[l:])
 
     def options(self, *args):
         """Return a new Query object, applying the given list of
         MapperOptions.
 
         """
-        return self.__options(False, *args)
+        return self._options(False, *args)
 
     def _conditional_options(self, *args):
-        return self.__options(True, *args)
+        return self._options(True, *args)
 
     @_generative()
-    def __options(self, conditional, *args):
+    def _options(self, conditional, *args):
         # most MapperOptions write to the '_attributes' dictionary,
         # so copy that as well
         self._attributes = self._attributes.copy()
@@ -645,7 +699,7 @@ class Query(object):
         self._params = self._params.copy()
         self._params.update(kwargs)
 
-    @_generative(__no_statement_condition, __no_limit_offset)
+    @_generative(_no_statement_condition, _no_limit_offset)
     def filter(self, criterion):
         """apply the given filtering criterion to the query and return the newly resulting ``Query``
 
@@ -674,7 +728,7 @@ class Query(object):
         return self.filter(sql.and_(*clauses))
 
 
-    @_generative(__no_statement_condition, __no_limit_offset)
+    @_generative(_no_statement_condition, _no_limit_offset)
     @util.accepts_a_list_as_starargs(list_deprecation='pending')
     def order_by(self, *criterion):
         """apply one or more ORDER BY criterion to the query and return the newly resulting ``Query``"""
@@ -689,7 +743,7 @@ class Query(object):
             else:
                 self._order_by = self._order_by + criterion
 
-    @_generative(__no_statement_condition, __no_limit_offset)
+    @_generative(_no_statement_condition, _no_limit_offset)
     @util.accepts_a_list_as_starargs(list_deprecation='pending')
     def group_by(self, *criterion):
         """apply one or more GROUP BY criterion to the query and return the newly resulting ``Query``"""
@@ -703,7 +757,7 @@ class Query(object):
         else:
             self._group_by = self._group_by + criterion
 
-    @_generative(__no_statement_condition, __no_limit_offset)
+    @_generative(_no_statement_condition, _no_limit_offset)
     def having(self, criterion):
         """apply a HAVING criterion to the query and return the newly resulting ``Query``."""
 
@@ -871,7 +925,7 @@ class Query(object):
         aliased, from_joinpoint = kwargs.pop('aliased', False), kwargs.pop('from_joinpoint', False)
         if kwargs:
             raise TypeError("unknown arguments: %s" % ','.join(kwargs.iterkeys()))
-        return self.__join(props, outerjoin=False, create_aliases=aliased, from_joinpoint=from_joinpoint)
+        return self._join(props, outerjoin=False, create_aliases=aliased, from_joinpoint=from_joinpoint)
 
     @util.accepts_a_list_as_starargs(list_deprecation='pending')
     def outerjoin(self, *props, **kwargs):
@@ -884,19 +938,19 @@ class Query(object):
         aliased, from_joinpoint = kwargs.pop('aliased', False), kwargs.pop('from_joinpoint', False)
         if kwargs:
             raise TypeError("unknown arguments: %s" % ','.join(kwargs.iterkeys()))
-        return self.__join(props, outerjoin=True, create_aliases=aliased, from_joinpoint=from_joinpoint)
+        return self._join(props, outerjoin=True, create_aliases=aliased, from_joinpoint=from_joinpoint)
 
-    @_generative(__no_statement_condition, __no_limit_offset)
-    def __join(self, keys, outerjoin, create_aliases, from_joinpoint):
+    @_generative(_no_statement_condition, _no_limit_offset)
+    def _join(self, keys, outerjoin, create_aliases, from_joinpoint):
 
         # copy collections that may mutate so they do not affect
         # the copied-from query.
-        self.__currenttables = set(self.__currenttables)
+        self._currenttables = set(self._currenttables)
         self._polymorphic_adapters = self._polymorphic_adapters.copy()
 
         # start from the beginning unless from_joinpoint is set.
         if not from_joinpoint:
-            self.__reset_joinpoint()
+            self._reset_joinpoint()
 
         clause = replace_clause_index = None
         
@@ -1031,11 +1085,11 @@ class Query(object):
 
                 elif prop:
                     # for joins across plain relation()s, try not to specify the
-                    # same joins twice.  the __currenttables collection tracks
+                    # same joins twice.  the _currenttables collection tracks
                     # what plain mapped tables we've joined to already.
 
-                    if prop.table in self.__currenttables:
-                        if prop.secondary is not None and prop.secondary not in self.__currenttables:
+                    if prop.table in self._currenttables:
+                        if prop.secondary is not None and prop.secondary not in self._currenttables:
                             # TODO: this check is not strong enough for different paths to the same endpoint which
                             # does not use secondary tables
                             raise sa_exc.InvalidRequestError("Can't join to property '%s'; a path to this "
@@ -1044,8 +1098,8 @@ class Query(object):
                         continue
 
                     if prop.secondary:
-                        self.__currenttables.add(prop.secondary)
-                    self.__currenttables.add(prop.table)
+                        self._currenttables.add(prop.secondary)
+                    self._currenttables.add(prop.table)
 
                     if of_type:
                         right_entity = of_type
@@ -1101,7 +1155,7 @@ class Query(object):
         # future joins with from_joinpoint=True join from our established right_entity.
         self._joinpoint = right_entity
 
-    @_generative(__no_statement_condition)
+    @_generative(_no_statement_condition)
     def reset_joinpoint(self):
         """return a new Query with the 'joinpoint' of this Query reset
         back to the starting mapper.  Subsequent generative calls will
@@ -1111,9 +1165,9 @@ class Query(object):
         the root.
 
         """
-        self.__reset_joinpoint()
+        self._reset_joinpoint()
 
-    @_generative(__no_clauseelement_condition)
+    @_generative(_no_clauseelement_condition)
     def select_from(self, from_obj):
         """Set the `from_obj` parameter of the query and return the newly
         resulting ``Query``.  This replaces the table which this Query selects
@@ -1134,7 +1188,7 @@ class Query(object):
             from_obj = from_obj[-1]
         if not isinstance(from_obj, expression.FromClause):
             raise sa_exc.ArgumentError("select_from() accepts FromClause objects only.")
-        self.__set_select_from(from_obj)
+        self._set_select_from(from_obj)
 
     def __getitem__(self, item):
         if isinstance(item, slice):
@@ -1157,7 +1211,7 @@ class Query(object):
         else:
             return list(self[item:item+1])[0]
 
-    @_generative(__no_statement_condition)
+    @_generative(_no_statement_condition)
     def slice(self, start, stop):
         """apply LIMIT/OFFSET to the ``Query`` based on a range and return the newly resulting ``Query``."""
         if start is not None and stop is not None:
@@ -1168,7 +1222,7 @@ class Query(object):
         elif start is not None and stop is None:
             self._offset = (self._offset or 0) + start
 
-    @_generative(__no_statement_condition)
+    @_generative(_no_statement_condition)
     def limit(self, limit):
         """Apply a ``LIMIT`` to the query and return the newly resulting
 
@@ -1177,7 +1231,7 @@ class Query(object):
         """
         self._limit = limit
 
-    @_generative(__no_statement_condition)
+    @_generative(_no_statement_condition)
     def offset(self, offset):
         """Apply an ``OFFSET`` to the query and return the newly resulting
         ``Query``.
@@ -1185,7 +1239,7 @@ class Query(object):
         """
         self._offset = offset
 
-    @_generative(__no_statement_condition)
+    @_generative(_no_statement_condition)
     def distinct(self):
         """Apply a ``DISTINCT`` to the query and return the newly resulting
         ``Query``.
@@ -1201,7 +1255,7 @@ class Query(object):
         """
         return list(self)
 
-    @_generative(__no_clauseelement_condition)
+    @_generative(_no_clauseelement_condition)
     def from_statement(self, statement):
         """Execute the given SELECT statement and return results.
 
@@ -1402,7 +1456,7 @@ class Query(object):
 
         if refresh_state is None:
             q = self._clone()
-            q.__no_criterion_condition("get")
+            q._no_criterion_condition("get")
         else:
             q = self._clone()
 
@@ -1424,7 +1478,7 @@ class Query(object):
 
         if lockmode is not None:
             q._lockmode = lockmode
-        q.__get_options(
+        q._get_options(
             populate_existing=bool(refresh_state),
             version_check=(lockmode is not None),
             only_load_props=only_load_props,
@@ -1454,18 +1508,24 @@ class Query(object):
                 kwargs.get('distinct', False))
 
     def count(self):
-        """Apply this query's criterion to a SELECT COUNT statement.
-
-        If column expressions or LIMIT/OFFSET/DISTINCT are present,
-        the query "SELECT count(1) FROM (SELECT ...)" is issued,
-        so that the result matches the total number of rows
-        this query would return.  For mapped entities,
-        the primary key columns of each is written to the
-        columns clause of the nested SELECT statement.
-
-        For a Query which is only against mapped entities,
-        a simpler "SELECT count(1) FROM table1, table2, ...
-        WHERE criterion" is issued.
+        """Return a count of rows this Query would return.
+        
+        For simple entity queries, count() issues
+        a SELECT COUNT, and will specifically count the primary
+        key column of the first entity only.  If the query uses 
+        LIMIT, OFFSET, or DISTINCT, count() will wrap the statement 
+        generated by this Query in a subquery, from which a SELECT COUNT
+        is issued, so that the contract of "how many rows
+        would be returned?" is honored.
+        
+        For queries that request specific columns or expressions, 
+        count() again makes no assumptions about those expressions
+        and will wrap everything in a subquery.  Therefore,
+        ``Query.count()`` is usually not what you want in this case.   
+        To count specific columns, often in conjunction with 
+        GROUP BY, use ``func.count()`` as an individual column expression
+        instead of ``Query.count()``.  See the ORM tutorial
+        for an example.
 
         """
         should_nest = [self._should_nest_selectable]
@@ -2027,7 +2087,9 @@ class _ColumnEntity(_QueryEntity):
                 return
 
         if not isinstance(column, sql.ColumnElement):
-            raise sa_exc.InvalidRequestError("Invalid column expression '%r'" % column)
+            raise sa_exc.InvalidRequestError(
+                "SQL expression, column, or mapped entity expected - got '%r'" % column
+            )
 
         # if the Column is unnamed, give it a
         # label() so that mutable column expressions
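
Tying the Query changes above together, enable_assertions(False) is the
hook the new permissive mode exposes to subclasses.  A sketch in the
spirit of the PreFilteredQuery recipe linked from CHANGES (the Thing class
and its ``public`` column are assumptions)::

    class PublicQuery(Query):
        # hypothetical subclass that narrows every result to "public" rows
        def get(self, ident):
            # get() normally asserts that no criterion is present;
            # enable_assertions(False) lets the pre-filter coexist with it.
            # note that identity-map hits still bypass the filter, which
            # the full recipe accounts for.
            q = self.enable_assertions(False).filter(Thing.public == True)
            return Query.get(q, ident)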
index ea5aae645e0523d68929baa6c383767196e2baed..d3d653de4f6c051e1c158a0813fa465451c7917e 100644 (file)
@@ -638,9 +638,9 @@ class Session(object):
         If no transaction is in progress, this method is a pass-through.
 
         This method rolls back the current transaction or nested transaction
-        regardless of subtransactions being in effect.  All subtrasactions up
+        regardless of subtransactions being in effect.  All subtransactions up
         to the first real transaction are closed.  Subtransactions occur when
-        begin() is called mulitple times.
+        begin() is called multiple times.
 
         """
         if self.transaction is None:
index d2ab214666175e88a1d13ff0ac3ab7c811b437f6..902658a0e47d1ed58488d58f5c6d6f028d3cd8a8 100644 (file)
@@ -33,7 +33,7 @@ def _register_attribute(strategy, mapper, useobject,
 
     prop = strategy.parent_property
     attribute_ext = util.to_list(prop.extension) or []
-
+        
     if useobject and prop.single_parent:
         attribute_ext.append(_SingleParentValidator(prop))
 
@@ -370,13 +370,16 @@ class LazyLoader(AbstractRelationLoader):
     def init_class_attribute(self, mapper):
         self.is_class_level = True
         
-        
+        # MANYTOONE currently only needs the "old" value for delete-orphan
+        # cascades.  the required _SingleParentValidator will enable active_history
+        # in that case.  otherwise we don't need the "old" value during backref operations.
         _register_attribute(self, 
                 mapper,
                 useobject=True,
                 callable_=self._class_level_loader,
                 uselist = self.parent_property.uselist,
                 typecallable = self.parent_property.collection_class,
+                active_history = self.parent_property.direction is not interfaces.MANYTOONE, 
                 )
 
     def lazy_clause(self, state, reverse_direction=False, alias_secondary=False, adapt_source=None):
index e2bc3fde53341bdafaa02f3893e23ffa6e0da25e..bca6b4f463da07ac973ec50ae4232491bcfc3ab4 100644 (file)
@@ -33,7 +33,9 @@ class UOWEventHandler(interfaces.AttributeExtension):
     """An event handler added to all relation attributes which handles
     session cascade operations.
     """
-
+    
+    active_history = False
+    
     def __init__(self, key):
         self.key = key
 
index 252fa8407fbce56a41ce0c098ebc2c49dcc1f39e..346bf884af9d388ba0c2a182815b132899ba6062 100644 (file)
@@ -65,14 +65,9 @@ class SchemaItem(visitors.Visitable):
     def __repr__(self):
         return "%s()" % self.__class__.__name__
 
-    @property
+    @util.memoized_property
     def info(self):
-        try:
-            return self._info
-        except AttributeError:
-            self._info = {}
-            return self._info
-
+        return {}
 
 def _get_table_key(name, schema):
     if schema is None:
@@ -223,8 +218,8 @@ class Table(SchemaItem, expression.TableClause):
 
         self.quote = kwargs.pop('quote', None)
         self.quote_schema = kwargs.pop('quote_schema', None)
-        if kwargs.get('info'):
-            self._info = kwargs.pop('info')
+        if 'info' in kwargs:
+            self.info = kwargs.pop('info')
 
         self._prefixes = kwargs.pop('prefixes', [])
 
@@ -263,7 +258,7 @@ class Table(SchemaItem, expression.TableClause):
                 setattr(self, key, kwargs.pop(key))
 
         if 'info' in kwargs:
-            self._info = kwargs.pop('info')
+            self.info = kwargs.pop('info')
 
         self._extra_kwargs(**kwargs)
         self._init_items(*args)
@@ -614,8 +609,9 @@ class Column(SchemaItem, expression.ColumnClause):
 
         util.set_creation_order(self)
 
-        if kwargs.get('info'):
-            self._info = kwargs.pop('info')
+        if 'info' in kwargs:
+            self.info = kwargs.pop('info')
+            
         if kwargs:
             raise exc.ArgumentError(
                 "Unknown arguments passed to Column: " + repr(kwargs.keys()))
index 8899486546fe76d8ee4954e66ca67c60ea5d53a2..66dc84f194c5a9a4aee583f820590aa23ddc413d 100644 (file)
@@ -412,8 +412,8 @@ class SQLCompiler(engine.Compiled):
         else:
             return text
 
-    def visit_unary(self, unary, **kwargs):
-        s = self.process(unary.element)
+    def visit_unary(self, unary, **kw):
+        s = self.process(unary.element, **kw)
         if unary.operator:
             s = OPERATORS[unary.operator] + s
         if unary.modifier:
index 6c116e5bc0a38ebb5030253d9799e290a048987c..2da5184e69c34a1a6b971faabfe2b1ea0d3fad6f 100644 (file)
@@ -407,8 +407,8 @@ def not_(clause):
 
 def distinct(expr):
     """Return a ``DISTINCT`` clause."""
-
-    return _UnaryExpression(expr, operator=operators.distinct_op)
+    expr = _literal_as_binds(expr)
+    return _UnaryExpression(expr, operator=operators.distinct_op, type_=expr.type)
 
 def between(ctest, cleft, cright):
     """Return a ``BETWEEN`` predicate clause.
@@ -1568,8 +1568,7 @@ class _CompareMixin(ColumnOperators):
 
     def distinct(self):
         """Produce a DISTINCT clause, i.e. ``DISTINCT <columnname>``"""
-
-        return _UnaryExpression(self, operator=operators.distinct_op)
+        return _UnaryExpression(self, operator=operators.distinct_op, type_=self.type)
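
The ``type_`` now carried on the unary expression is what lets result-set
conversion apply; a small sketch (the table and its Unicode-typed column
are illustrative)::

    # DISTINCT inherits the wrapped column's type, so a Unicode column
    # still comes back as unicode objects in result rows
    stmt = select([distinct(users.c.name)])
    for row in engine.execute(stmt):
        assert isinstance(row[0], unicode)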
 
     def between(self, cleft, cright):
         """Produce a BETWEEN clause, i.e. ``<column> BETWEEN <cleft> AND <cright>``"""
index f1f329b5e27a31b60b52bd022b2ab69bedb903e2..ac95c3a20950a8d45eb66dce9a6322d8cde0171a 100644 (file)
@@ -53,24 +53,21 @@ def find_tables(clause, check_columns=False, include_aliases=False, include_join
     tables = []
     _visitors = {}
     
-    def visit_something(elem):
-        tables.append(elem)
-        
     if include_selects:
-        _visitors['select'] = _visitors['compound_select'] = visit_something
+        _visitors['select'] = _visitors['compound_select'] = tables.append
     
     if include_joins:
-        _visitors['join'] = visit_something
+        _visitors['join'] = tables.append
         
     if include_aliases:
-        _visitors['alias']  = visit_something
+        _visitors['alias']  = tables.append
 
     if check_columns:
         def visit_column(column):
             tables.append(column.table)
         _visitors['column'] = visit_column
 
-    _visitors['table'] = visit_something
+    _visitors['table'] = tables.append
 
     visitors.traverse(clause, {'column_collections':False}, _visitors)
     return tables
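The rewrite above relies on any one-argument callable serving as a visitor, so the bound method tables.append replaces the visit_something closure directly; sketched (clause is illustrative):

    tables = []
    _visitors = {'table': tables.append}   # bound method as visitor
    visitors.traverse(clause, {'column_collections': False}, _visitors)
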
index 2537eb695e2ce67612346ed8cf5aa7a7d73c8750..d8a541abf0529b539e58aaf7fb4ca410cfc4609e 100644 (file)
@@ -1178,6 +1178,6 @@ class BinaryTest(TestBase, AssertsExecutionResults):
 
     def load_stream(self, name, len=3000):
         f = os.path.join(os.path.dirname(__file__), "..", name)
-        return file(f).read(len)
+        return file(f, 'rb').read(len)
 
 
index 9f753039a54ffe401e267f09e41f0f226cf69336..784a7b9ce619a8cbdd3b6726caef33ab9c34eb2b 100644 (file)
@@ -162,3 +162,29 @@ class TableOptionsTest(TestBase, AssertsCompiledSQL):
           "CREATE VIRTUAL TABLE temporary_table_2 (col1 INTEGER)"
         )
 
+    def test_table_info(self):
+        metadata = MetaData()
+        t1 = Table('foo', metadata, info={'x':'y'})
+        t2 = Table('bar', metadata, info={})
+        t3 = Table('bat', metadata)
+        assert t1.info == {'x':'y'}
+        assert t2.info == {}
+        assert t3.info == {}
+        for t in (t1, t2, t3):
+            t.info['bar'] = 'zip'
+            assert t.info['bar'] == 'zip'
+
+class ColumnOptionsTest(TestBase):
+    def test_column_info(self):
+        
+        c1 = Column('foo', info={'x':'y'})
+        c2 = Column('bar', info={})
+        c3 = Column('bat')
+        assert c1.info == {'x':'y'}
+        assert c2.info == {}
+        assert c3.info == {}
+        
+        for c in (c1, c2, c3):
+            c.info['bar'] = 'zip'
+            assert c.info['bar'] == 'zip'
+
index 7ef749c96531acff65c4b66ad9992b1c0cbe58e0..4a5775218dcaa5f1e67a88aaf61e821e9df54e3e 100644 (file)
@@ -1,4 +1,7 @@
-from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
+from sqlalchemy.test.testing import eq_, assert_raises
+import copy
+import pickle
+
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from sqlalchemy.orm.collections import collection
@@ -15,12 +18,15 @@ class DictCollection(dict):
     def remove(self, obj):
         del self[obj.foo]
 
+
 class SetCollection(set):
     pass
 
+
 class ListCollection(list):
     pass
 
+
 class ObjectCollection(object):
     def __init__(self):
         self.values = list()
@@ -33,6 +39,21 @@ class ObjectCollection(object):
     def __iter__(self):
         return iter(self.values)
 
+
+class Parent(object):
+    kids = association_proxy('children', 'name')
+    def __init__(self, name):
+        self.name = name
+
+class Child(object):
+    def __init__(self, name):
+        self.name = name
+
+class KVChild(object):
+    def __init__(self, name, value):
+        self.name = name
+        self.value = value
+
 class _CollectionOperations(TestBase):
     def setup(self):
         collection_class = self.collection_class
@@ -837,29 +858,23 @@ class ReconstitutionTest(TestBase):
         metadata.create_all()
         parents.insert().execute(name='p1')
 
-        class Parent(object):
-            kids = association_proxy('children', 'name')
-            def __init__(self, name):
-                self.name = name
-
-        class Child(object):
-            def __init__(self, name):
-                self.name = name
-
-        mapper(Parent, parents, properties=dict(children=relation(Child)))
-        mapper(Child, children)
 
         self.metadata = metadata
-        self.Parent = Parent
-
+        self.parents = parents
+        self.children = children
+        
     def teardown(self):
         self.metadata.drop_all()
+        clear_mappers()
 
     def test_weak_identity_map(self):
+        mapper(Parent, self.parents, properties=dict(children=relation(Child)))
+        mapper(Child, self.children)
+
         session = create_session(weak_identity_map=True)
 
         def add_child(parent_name, child_name):
-            parent = (session.query(self.Parent).
+            parent = (session.query(Parent).
                       filter_by(name=parent_name)).one()
             parent.kids.append(child_name)
 
@@ -869,12 +884,14 @@ class ReconstitutionTest(TestBase):
         add_child('p1', 'c2')
 
         session.flush()
-        p = session.query(self.Parent).filter_by(name='p1').one()
+        p = session.query(Parent).filter_by(name='p1').one()
         assert set(p.kids) == set(['c1', 'c2']), p.kids
 
     def test_copy(self):
-        import copy
-        p = self.Parent('p1')
+        mapper(Parent, self.parents, properties=dict(children=relation(Child)))
+        mapper(Child, self.children)
+
+        p = Parent('p1')
         p.kids.extend(['c1', 'c2'])
         p_copy = copy.copy(p)
         del p
@@ -882,4 +899,51 @@ class ReconstitutionTest(TestBase):
 
         assert set(p_copy.kids) == set(['c1', 'c2']), p.kids
 
+    def test_pickle_list(self):
+        mapper(Parent, self.parents, properties=dict(children=relation(Child)))
+        mapper(Child, self.children)
+
+        p = Parent('p1')
+        p.kids.extend(['c1', 'c2'])
+
+        r1 = pickle.loads(pickle.dumps(p))
+        assert r1.kids == ['c1', 'c2']
 
+        r2 = pickle.loads(pickle.dumps(p.kids))
+        assert r2 == ['c1', 'c2']
+
+    def test_pickle_set(self):
+        mapper(Parent, self.parents, properties=dict(children=relation(Child, collection_class=set)))
+        mapper(Child, self.children)
+
+        p = Parent('p1')
+        p.kids.update(['c1', 'c2'])
+
+        r1 = pickle.loads(pickle.dumps(p))
+        assert r1.kids == set(['c1', 'c2'])
+
+        r2 = pickle.loads(pickle.dumps(p.kids))
+        assert r2 == set(['c1', 'c2'])
+
+    def test_pickle_dict(self):
+        mapper(Parent, self.parents, properties=dict(
+                    children=relation(KVChild, collection_class=collections.mapped_collection(PickleKeyFunc('name')))
+                ))
+        mapper(KVChild, self.children)
+
+        p = Parent('p1')
+        p.kids.update({'c1':'v1', 'c2':'v2'})
+        assert p.kids == {'c1':'c1', 'c2':'c2'}
+        
+        r1 = pickle.loads(pickle.dumps(p))
+        assert r1.kids == {'c1':'c1', 'c2':'c2'}
+
+        r2 = pickle.loads(pickle.dumps(p.kids))
+        assert r2 == {'c1':'c1', 'c2':'c2'}
+
+class PickleKeyFunc(object):
+    def __init__(self, name):
+        self.name = name
+    
+    def __call__(self, obj):
+        return getattr(obj, self.name)
\ No newline at end of file
index 6bf709dfc74e7b76f0a7c0550547de5acbe9f3c6..9ca8356918922e68eeeefafb4ff758dc13598f8b 100644 (file)
@@ -449,6 +449,15 @@ class DeclarativeTest(DeclarativeTestBase):
             define)
         
     def test_table_args(self):
+        
+        def err():
+            class Foo(Base):
+                __tablename__ = 'foo'
+                __table_args__ = (ForeignKeyConstraint(['id'], ['foo.id']),)
+                id = Column('id', Integer, primary_key=True)
+                
+        assert_raises_message(sa.exc.ArgumentError, "Tuple form of __table_args__ is ", err)
+        
         class Foo(Base):
             __tablename__ = 'foo'
             __table_args__ = {'mysql_engine':'InnoDB'}
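For reference, a hedged sketch of the tuple form the new error message points toward: positional table arguments followed by a trailing dict of keyword arguments (identifiers illustrative):

    class Foo(Base):
        __tablename__ = 'foo'
        __table_args__ = (
            ForeignKeyConstraint(['id'], ['foo.id']),
            {'mysql_engine': 'InnoDB'},     # trailing dict required
        )
        id = Column('id', Integer, primary_key=True)
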
index 435f26cbaee2b8c24056c350b057bfe5ecd5690c..e9cd6093d2d57048f85fd265d68e576458dc73bc 100644 (file)
@@ -74,20 +74,33 @@ class FalseDiscriminatorTest(_base.MappedTest):
         global t1
         t1 = Table('t1', metadata, 
             Column('id', Integer, primary_key=True, test_needs_autoincrement=True), 
-            Column('type', Integer, nullable=False))
+            Column('type', Boolean, nullable=False))
         
-    def test_false_discriminator(self):
+    def test_false_on_sub(self):
         class Foo(object):pass
         class Bar(Foo):pass
-        mapper(Foo, t1, polymorphic_on=t1.c.type, polymorphic_identity=1)
-        mapper(Bar, inherits=Foo, polymorphic_identity=0)
+        mapper(Foo, t1, polymorphic_on=t1.c.type, polymorphic_identity=True)
+        mapper(Bar, inherits=Foo, polymorphic_identity=False)
         sess = create_session()
-        f1 = Bar()
-        sess.add(f1)
+        b1 = Bar()
+        sess.add(b1)
         sess.flush()
-        assert f1.type == 0
+        assert b1.type is False
         sess.expunge_all()
         assert isinstance(sess.query(Foo).one(), Bar)
+
+    def test_false_on_base(self):
+        class Ding(object):pass
+        class Bat(Ding):pass
+        mapper(Ding, t1, polymorphic_on=t1.c.type, polymorphic_identity=False)
+        mapper(Bat, inherits=Ding, polymorphic_identity=True)
+        sess = create_session()
+        d1 = Ding()
+        sess.add(d1)
+        sess.flush()
+        assert d1.type is False
+        sess.expunge_all()
+        assert sess.query(Ding).one() is not None
         
 class PolymorphicSynonymTest(_base.MappedTest):
     @classmethod
@@ -900,10 +913,8 @@ class OverrideColKeyTest(_base.MappedTest):
         assert sess.query(Sub).get(s1.base_id).data == "this is base"
 
 class OptimizedLoadTest(_base.MappedTest):
-    """test that the 'optimized load' routine doesn't crash when 
-    a column in the join condition is not available.
+    """tests for the "optimized load" routine."""
     
-    """
     @classmethod
     def define_tables(cls, metadata):
         global base, sub
@@ -918,7 +929,10 @@ class OptimizedLoadTest(_base.MappedTest):
         )
     
     def test_optimized_passes(self):
-        class Base(object):
+        """"test that the 'optimized load' routine doesn't crash when 
+        a column in the join condition is not available."""
+        
+        class Base(_base.BasicEntity):
             pass
         class Sub(Base):
             pass
@@ -928,21 +942,66 @@ class OptimizedLoadTest(_base.MappedTest):
         # redefine Sub's "id" to favor the "id" col in the subtable.
         # "id" is also part of the primary join condition
         mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', properties={'id':sub.c.id})
-        sess = create_session()
-        s1 = Sub()
-        s1.data = 's1data'
-        s1.sub = 's1sub'
+        sess = sessionmaker()()
+        s1 = Sub(data='s1data', sub='s1sub')
         sess.add(s1)
-        sess.flush()
+        sess.commit()
         sess.expunge_all()
         
         # load s1 via Base.  s1.id won't populate since it's relative to 
         # the "sub" table.  The optimized load kicks in and tries to 
         # generate on the primary join, but cannot since "id" is itself unloaded.
         # the optimized load needs to return "None" so regular full-row loading proceeds
-        s1 = sess.query(Base).get(s1.id)
+        s1 = sess.query(Base).first()
         assert s1.sub == 's1sub'
 
+    def test_column_expression(self):
+        class Base(_base.BasicEntity):
+            pass
+        class Sub(Base):
+            pass
+        mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base')
+        mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', properties={
+            'concat':column_property(sub.c.sub + "|" + sub.c.sub)
+        })
+        sess = sessionmaker()()
+        s1 = Sub(data='s1data', sub='s1sub')
+        sess.add(s1)
+        sess.commit()
+        sess.expunge_all()
+        s1 = sess.query(Base).first()
+        assert s1.concat == 's1sub|s1sub'
+
+    def test_column_expression_joined(self):
+        class Base(_base.ComparableEntity):
+            pass
+        class Sub(Base):
+            pass
+        mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base')
+        mapper(Sub, sub, inherits=Base, polymorphic_identity='sub', properties={
+            'concat':column_property(base.c.data + "|" + sub.c.sub)
+        })
+        sess = sessionmaker()()
+        s1 = Sub(data='s1data', sub='s1sub')
+        s2 = Sub(data='s2data', sub='s2sub')
+        s3 = Sub(data='s3data', sub='s3sub')
+        sess.add_all([s1, s2, s3])
+        sess.commit()
+        sess.expunge_all()
+        # query a bunch of rows to ensure there's no cartesian
+        # product against "base" occurring, it is in fact
+        # detecting that "base" needs to be in the join 
+        # criterion
+        eq_(
+            sess.query(Base).order_by(Base.id).all(),
+            [
+                Sub(data='s1data', sub='s1sub', concat='s1data|s1sub'),
+                Sub(data='s2data', sub='s2sub', concat='s2data|s2sub'),
+                Sub(data='s3data', sub='s3sub', concat='s3data|s3sub')
+            ]
+        )
+        
+        
 class PKDiscriminatorTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
diff --git a/test/orm/test_backref_mutations.py b/test/orm/test_backref_mutations.py
new file mode 100644 (file)
index 0000000..1ecf027
--- /dev/null
@@ -0,0 +1,474 @@
+"""
+A series of tests asserting that moving objects between collections and
+scalar attributes results in the expected state w.r.t. backrefs,
+add/remove events, etc.
+
+There's a particular focus on relations with "uselist=False", since in
+these cases the re-assignment of an attribute means the previous owner
+needs an UPDATE in the database.
+
+"""
+
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+from sqlalchemy import Integer, String, ForeignKey, Sequence, exc as sa_exc
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session, class_mapper, backref, sessionmaker
+from sqlalchemy.orm import attributes, exc as orm_exc
+from sqlalchemy.test import testing
+from sqlalchemy.test.testing import eq_
+from test.orm import _base, _fixtures
+
+class O2MCollectionTest(_fixtures.FixtureTest):
+    run_inserts = None
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_mappers(cls):
+        mapper(Address, addresses)
+        mapper(User, users, properties = dict(
+            addresses = relation(Address, backref="user"),
+        ))
+
+    @testing.resolve_artifact_names
+    def test_collection_move_hitslazy(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        a2 = Address(email_address="address2")
+        a3 = Address(email_address="address3")
+        u1 = User(name='jack', addresses=[a1, a2, a3])
+        u2 = User(name='ed')
+        sess.add_all([u1, a1, a2, a3])
+        sess.commit()
+        
+        #u1.addresses
+        
+        def go():
+            u2.addresses.append(a1)
+            u2.addresses.append(a2)
+            u2.addresses.append(a3)
+        self.assert_sql_count(testing.db, go, 0)
+        
+    @testing.resolve_artifact_names
+    def test_collection_move_preloaded(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        u1 = User(name='jack', addresses=[a1])
+
+        u2 = User(name='ed')
+        sess.add_all([u1, u2])
+        sess.commit() # everything is expired
+
+        # load u1.addresses collection
+        u1.addresses
+
+        u2.addresses.append(a1)
+
+        # backref fires
+        assert a1.user is u2
+
+        # doesn't extend to the previous collection, though,
+        # which was already loaded.
+        # flushing at this point means it's anyone's guess.
+        assert a1 in u1.addresses
+        assert a1 in u2.addresses
+
+    @testing.resolve_artifact_names
+    def test_collection_move_notloaded(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        u1 = User(name='jack', addresses=[a1])
+
+        u2 = User(name='ed')
+        sess.add_all([u1, u2])
+        sess.commit() # everything is expired
+
+        u2.addresses.append(a1)
+
+        # backref fires
+        assert a1.user is u2
+        
+        # u1.addresses wasn't loaded,
+        # so when it does load, it's correct
+        assert a1 not in u1.addresses
+        assert a1 in u2.addresses
+
+    @testing.resolve_artifact_names
+    def test_collection_move_commitfirst(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        u1 = User(name='jack', addresses=[a1])
+
+        u2 = User(name='ed')
+        sess.add_all([u1, u2])
+        sess.commit() # everything is expired
+
+        # load u1.addresses collection
+        u1.addresses
+
+        u2.addresses.append(a1)
+
+        # backref fires
+        assert a1.user is u2
+        
+        # everything expires, no changes in 
+        # u1.addresses, so all is fine
+        sess.commit()
+        assert a1 not in u1.addresses
+        assert a1 in u2.addresses
+
+    @testing.resolve_artifact_names
+    def test_scalar_move_preloaded(self):
+        sess = sessionmaker()()
+
+        u1 = User(name='jack')
+        u2 = User(name='ed')
+        a1 = Address(email_address='a1')
+        a1.user = u1
+        sess.add_all([u1, u2, a1])
+        sess.commit()
+
+        # u1.addresses is loaded
+        u1.addresses
+
+        # direct set - the fetching of the 
+        # "old" u1 here allows the backref
+        # to remove it from the addresses collection
+        a1.user = u2
+
+        assert a1 not in u1.addresses
+        assert a1 in u2.addresses
+
+        
+    @testing.resolve_artifact_names
+    def test_scalar_move_notloaded(self):
+        sess = sessionmaker()()
+
+        u1 = User(name='jack')
+        u2 = User(name='ed')
+        a1 = Address(email_address='a1')
+        a1.user = u1
+        sess.add_all([u1, u2, a1])
+        sess.commit()
+
+        # direct set - the fetching of the 
+        # "old" u1 here allows the backref
+        # to remove it from the addresses collection
+        a1.user = u2
+
+        assert a1 not in u1.addresses
+        assert a1 in u2.addresses
+
+    @testing.resolve_artifact_names
+    def test_scalar_move_commitfirst(self):
+        sess = sessionmaker()()
+
+        u1 = User(name='jack')
+        u2 = User(name='ed')
+        a1 = Address(email_address='a1')
+        a1.user = u1
+        sess.add_all([u1, u2, a1])
+        sess.commit()
+
+        # u1.addresses is loaded
+        u1.addresses
+
+        # direct set - the fetching of the 
+        # "old" u1 here allows the backref
+        # to remove it from the addresses collection
+        a1.user = u2
+        
+        sess.commit()
+        assert a1 not in u1.addresses
+        assert a1 in u2.addresses
+
+class O2OScalarBackrefMoveTest(_fixtures.FixtureTest):
+    run_inserts = None
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_mappers(cls):
+        mapper(Address, addresses)
+        mapper(User, users, properties = {
+            'address':relation(Address, backref=backref("user"), uselist=False)
+        })
+
+    @testing.resolve_artifact_names
+    def test_collection_move_preloaded(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        u1 = User(name='jack', address=a1)
+
+        u2 = User(name='ed')
+        sess.add_all([u1, u2])
+        sess.commit() # everything is expired
+
+        # load u1.address
+        u1.address
+
+        # reassign
+        u2.address = a1
+        assert u2.address is a1
+
+        # backref fires
+        assert a1.user is u2
+
+        # doesn't extend to the previous attribute, though.
+        # flushing at this point means it's anyone's guess.
+        assert u1.address is a1
+        assert u2.address is a1
+
+    @testing.resolve_artifact_names
+    def test_scalar_move_preloaded(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        a2 = Address(email_address="address1")
+        u1 = User(name='jack', address=a1)
+
+        sess.add_all([u1, a1, a2])
+        sess.commit() # everything is expired
+
+        # load a1.user
+        a1.user
+        
+        # reassign
+        a2.user = u1
+
+        # backref fires
+        assert u1.address is a2
+        
+        # stays on both sides
+        assert a1.user is u1
+        assert a2.user is u1
+
+    @testing.resolve_artifact_names
+    def test_collection_move_notloaded(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        u1 = User(name='jack', address=a1)
+
+        u2 = User(name='ed')
+        sess.add_all([u1, u2])
+        sess.commit() # everything is expired
+
+        # reassign
+        u2.address = a1
+        assert u2.address is a1
+
+        # backref fires
+        assert a1.user is u2
+        
+        # u1.address loads now (after autoflush) and is correct
+        assert u1.address is None
+        assert u2.address is a1
+
+    @testing.resolve_artifact_names
+    def test_scalar_move_notloaded(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        a2 = Address(email_address="address1")
+        u1 = User(name='jack', address=a1)
+
+        sess.add_all([u1, a1, a2])
+        sess.commit() # everything is expired
+
+        # reassign
+        a2.user = u1
+
+        # backref fires
+        assert u1.address is a2
+
+        # stays on both sides
+        assert a1.user is u1
+        assert a2.user is u1
+
+    @testing.resolve_artifact_names
+    def test_collection_move_commitfirst(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        u1 = User(name='jack', address=a1)
+
+        u2 = User(name='ed')
+        sess.add_all([u1, u2])
+        sess.commit() # everything is expired
+
+        # load u1.address
+        u1.address
+
+        # reassign
+        u2.address = a1
+        assert u2.address is a1
+
+        # backref fires
+        assert a1.user is u2
+
+        # the commit cancels out u1.address
+        # being loaded; on next access it's fine.
+        sess.commit()
+        assert u1.address is None
+        assert u2.address is a1
+
+    @testing.resolve_artifact_names
+    def test_scalar_move_commitfirst(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        a2 = Address(email_address="address2")
+        u1 = User(name='jack', address=a1)
+
+        sess.add_all([u1, a1, a2])
+        sess.commit() # everything is expired
+
+        # load
+        assert a1.user is u1
+        
+        # reassign
+        a2.user = u1
+
+        # backref fires
+        assert u1.address is a2
+
+        # didn't work this way, though
+        assert a1.user is u1
+        
+        # moves appropriately after commit
+        sess.commit()
+        assert u1.address is a2
+        assert a1.user is None
+        assert a2.user is u1
+
+class O2OScalarMoveTest(_fixtures.FixtureTest):
+    run_inserts = None
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_mappers(cls):
+        mapper(Address, addresses)
+        mapper(User, users, properties = {
+            'address':relation(Address, uselist=False)
+        })
+
+    @testing.resolve_artifact_names
+    def test_collection_move_commitfirst(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        u1 = User(name='jack', address=a1)
+
+        u2 = User(name='ed')
+        sess.add_all([u1, u2])
+        sess.commit() # everything is expired
+
+        # load u1.address
+        u1.address
+
+        # reassign
+        u2.address = a1
+        assert u2.address is a1
+
+        # the commit cancels out u1.address
+        # being loaded; on next access it's fine.
+        sess.commit()
+        assert u1.address is None
+        assert u2.address is a1
+
+class O2OScalarOrphanTest(_fixtures.FixtureTest):
+    run_inserts = None
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_mappers(cls):
+        mapper(Address, addresses)
+        mapper(User, users, properties = {
+            'address':relation(Address, uselist=False, 
+                backref=backref('user', single_parent=True, cascade="all, delete-orphan"))
+        })
+
+    @testing.resolve_artifact_names
+    def test_m2o_event(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        u1 = User(name='jack', address=a1)
+        
+        sess.add(u1)
+        sess.commit()
+        sess.expunge(u1)
+        
+        u2 = User(name='ed')
+        # the _SingleParent extension sets the backref get to "active"!
+        # u1 gets loaded and deleted
+        u2.address = a1
+        sess.commit()
+        assert sess.query(User).count() == 1
+        
+    
+class M2MScalarMoveTest(_fixtures.FixtureTest):
+    run_inserts = None
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_mappers(cls):
+        mapper(Item, items, properties={
+            'keyword':relation(Keyword, secondary=item_keywords, uselist=False, backref=backref("item", uselist=False))
+        })
+        mapper(Keyword, keywords)
+    
+    @testing.resolve_artifact_names
+    def test_collection_move_preloaded(self):
+        sess = sessionmaker()()
+        
+        k1 = Keyword(name='k1')
+        i1 = Item(description='i1', keyword=k1)
+        i2 = Item(description='i2')
+
+        sess.add_all([i1, i2, k1])
+        sess.commit() # everything is expired
+        
+        # load i1.keyword
+        assert i1.keyword is k1
+        
+        i2.keyword = k1
+
+        assert k1.item is i2
+        
+        # nothing happens on the old side.
+        assert i1.keyword is k1
+        assert i2.keyword is k1
+
+    @testing.resolve_artifact_names
+    def test_collection_move_notloaded(self):
+        sess = sessionmaker()()
+
+        k1 = Keyword(name='k1')
+        i1 = Item(description='i1', keyword=k1)
+        i2 = Item(description='i2')
+
+        sess.add_all([i1, i2, k1])
+        sess.commit() # everything is expired
+
+        i2.keyword = k1
+
+        assert k1.item is i2
+
+        assert i1.keyword is None
+        assert i2.keyword is k1
+
+    @testing.resolve_artifact_names
+    def test_collection_move_commit(self):
+        sess = sessionmaker()()
+
+        k1 = Keyword(name='k1')
+        i1 = Item(description='i1', keyword=k1)
+        i2 = Item(description='i2')
+
+        sess.add_all([i1, i2, k1])
+        sess.commit() # everything is expired
+
+        # load i1.keyword
+        assert i1.keyword is k1
+
+        i2.keyword = k1
+
+        assert k1.item is i2
+
+        sess.commit()
+        assert i1.keyword is None
+        assert i2.keyword is k1
index 4ce36f1ff59f3c0a50145923e5ec8f607c80f020..cb60bef69a54f0a03a0a7a8b1b29647284b36df3 100644 (file)
@@ -225,6 +225,11 @@ class InvalidGenerationsTest(QueryTest):
 
             assert_raises(sa_exc.InvalidRequestError, q.having, 'foo')
     
+            q.enable_assertions(False).join("addresses")
+            q.enable_assertions(False).filter(User.name=='ed')
+            q.enable_assertions(False).order_by('foo')
+            q.enable_assertions(False).group_by('foo')
+            
     def test_no_from(self):
         s = create_session()
     
@@ -236,6 +241,10 @@ class InvalidGenerationsTest(QueryTest):
         
         q = s.query(User).order_by(User.id)
         assert_raises(sa_exc.InvalidRequestError, q.select_from, users)
+
+        assert_raises(sa_exc.InvalidRequestError, q.select_from, users)
+        
+        q.enable_assertions(False).select_from(users)
         
         # this is fine, however
         q.from_self()
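A hedged sketch of the kind of Query subclass enable_assertions(False) is aimed at: one that injects criteria where Query would otherwise assert a pristine state (class name and filter are illustrative):

    from sqlalchemy.orm.query import Query

    class PreFilteredQuery(Query):
        def undeleted(self):
            # bypass the "no criterion may be applied" style assertions
            return self.enable_assertions(False).\
                        filter(User.deleted == False)
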
index 9d3d785f95831df762bfd231edc364488fa24217..bbc399aa6ba38846130ea35ce7d4a0fb90234594 100644 (file)
@@ -481,6 +481,11 @@ class QueryTest(TestBase):
         self.assert_(r['query_users.user_id']) == 1
         self.assert_(r['query_users.user_name']) == "john"
 
+        # unary expressions
+        r = select([users.c.user_name.distinct()]).order_by(users.c.user_name).execute().first()
+        eq_(r[users.c.user_name], 'jack')
+        eq_(r.user_name, 'jack')
+
     def test_result_case_sensitivity(self):
         """test name normalization for result sets."""
         
@@ -493,6 +498,7 @@ class QueryTest(TestBase):
         
         assert row.keys() == ["case_insensitive", "CaseSensitive"]
 
+        
     def test_row_as_args(self):
         users.insert().execute(user_id=1, user_name='john')
         r = users.select(users.c.user_id==1).execute().first()
index 841dda53fe5cfbdb88ec07ad100726ca2429c0a4..9c90549e29fd12acdd4eb661edba2b8d58d527f1 100644 (file)
@@ -496,6 +496,16 @@ class ExpressionTest(TestBase, AssertsExecutionResults):
         # this one relies upon anonymous labeling to assemble result
         # processing rules on the column.
         assert testing.db.execute(select([expr])).scalar() == -15
+        
+    def test_distinct(self):
+        s = select([distinct(test_table.c.avalue)])
+        eq_(testing.db.execute(s).scalar(), 25)
+
+        s = select([test_table.c.avalue.distinct()])
+        eq_(testing.db.execute(s).scalar(), 25)
+
+        assert distinct(test_table.c.data).type == test_table.c.data.type
+        assert test_table.c.data.distinct().type == test_table.c.data.type
 
 class DateTest(TestBase, AssertsExecutionResults):
     @classmethod