]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- merged -r5727:5797 of trunk
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 14 Feb 2009 23:02:33 +0000 (23:02 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 14 Feb 2009 23:02:33 +0000 (23:02 +0000)
- newest pg8000 handles unicode statements correctly.

29 files changed:
CHANGES
examples/custom_attributes/listen_for_events.py
lib/sqlalchemy/dialects/postgres/base.py
lib/sqlalchemy/dialects/postgres/pg8000.py
lib/sqlalchemy/ext/sqlsoup.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/evaluator.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/unitofwork.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/sql/operators.py
lib/sqlalchemy/util.py
test/dialect/mssql.py
test/dialect/postgres.py
test/orm/cascade.py
test/orm/expire.py
test/orm/inheritance/query.py
test/orm/query.py
test/orm/relationships.py
test/profiling/memusage.py
test/sql/functions.py
test/sql/generative.py
test/sql/labels.py
test/sql/select.py

diff --git a/CHANGES b/CHANGES
index 36c8398c0ac7b08c1801a681de998e3230d503cd..1603ce65946e8552d447141a4cf0a3365995e951 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -3,6 +3,77 @@
 =======
 CHANGES
 =======
+0.5.3
+=====
+- orm
+    - Query now implements __clause_element__() which produces
+      its selectable, which means a Query instance can be accepted 
+      in many SQL expressions, including col.in_(query), 
+      union(query1, query2), select([foo]).select_from(query), 
+      etc.
+
+    - a session.expire() on a particular collection attribute
+      will clear any pending backref additions as well, so that
+      the next access correctly returns only what was present
+      in the database.  Presents some degree of a workaround for 
+      [ticket:1315], although we are considering removing the 
+      flush([objects]) feature altogether.
+      
+    - improvements to the "determine direction" logic of 
+      relation() such that the direction of tricky situations 
+      like mapper(A.join(B)) -> relation-> mapper(B) can be
+      determined.
+      
+    - When flushing partial sets of objects using session.flush([somelist]),
+      pending objects which remain pending after the operation won't
+      inadvertently be added as persistent. [ticket:1306]
+     
+    - Added "post_configure_attribute" method to InstrumentationManager,
+      so that the "listen_for_events.py" example works again.
+      [ticket:1314]
+      
+    - Fixed bugs in Query regarding simultaneous selection of 
+      multiple joined-table inheritance entities with common base 
+      classes:
+      
+      - previously the adaption applied to "B" on 
+        "A JOIN B" would be erroneously partially applied 
+        to "A".
+      
+      - comparisons on relations (i.e. A.related==someb)
+        were not getting adapted when they should.
+      
+      - Other filterings, like 
+        query(A).join(A.bs).filter(B.foo=='bar'), were erroneously
+        adapting "B.foo" as though it were an "A".
+      
+- sql
+    - Fixed missing _label attribute on Function object, others
+      when used in a select() with use_labels (such as when used
+      in an ORM column_property()).  [ticket:1302]
+
+    - anonymous alias names now truncate down to the max length
+      allowed by the dialect.  More significant on DBs like
+      Oracle with very small character limits. [ticket:1309]
+      
+    - the __selectable__() interface has been replaced entirely
+      by __clause_element__().
+
+    - The per-dialect cache used by TypeEngine to cache
+      dialect-specific types is now a WeakKeyDictionary.
+      This to prevent dialect objects from 
+      being referenced forever for an application that 
+      creates an arbitrarily large number of engines
+      or dialects.   There is a small performance penalty
+      which will be resolved in 0.6.  [ticket:1299]
+
+- postgres
+    - Index reflection won't fail when an index with 
+      multiple expressions is encountered.
+
+- mssql
+    - Preliminary support for pymssql 1.0.1
+    
 0.5.2
 ======
 
index c028e0fb48220d530b3d4be12fc6f913f4fece73..de28df5b3a01e76244d45465c5cd83402c4a9893 100644 (file)
@@ -7,11 +7,10 @@ across the board.
 from sqlalchemy.orm.interfaces import AttributeExtension, InstrumentationManager
 
 class InstallListeners(InstrumentationManager):
-    def instrument_attribute(self, class_, key, inst):
+    def post_configure_attribute(self, class_, key, inst):
         """Add an event listener to an InstrumentedAttribute."""
         
         inst.impl.extensions.insert(0, AttributeListener(key))
-        return super(InstallListeners, self).instrument_attribute(class_, key, inst)
         
 class AttributeListener(AttributeExtension):
     """Generic event listener.  
index 7db0dd882333ee9e33f92db624731d153f1f5efb..705778cc5b16576cdba08b6f3b20dc60543d1493 100644 (file)
@@ -823,10 +823,11 @@ class PGDialect(default.DefaultDialect):
         sv_idx_name = None
         for row in c.fetchall():
             idx_name, unique, expr, prd, col = row
-            if expr and not idx_name == sv_idx_name:
-                util.warn(
-                  "Skipped unsupported reflection of expression-based index %s"
-                  % idx_name)
+            if expr:
+                if idx_name != sv_idx_name:
+                    util.warn(
+                      "Skipped unsupported reflection of expression-based index %s"
+                      % idx_name)
                 sv_idx_name = idx_name
                 continue
             if prd and not idx_name == sv_idx_name:
index 00636dbfe46d70541b27cd6a7349cd18a5ec4f02..47ccab3f8bffb112c71c9e2ea411b3e797b4040b 100644 (file)
@@ -45,7 +45,7 @@ class Postgres_pg8000ExecutionContext(default.DefaultExecutionContext):
 class Postgres_pg8000(PGDialect):
     driver = 'pg8000'
 
-    supports_unicode_statements = False
+    supports_unicode_statements = True
     
     supports_unicode_binds = True
     
index 37b9d8fa89cc7693a127926f4320ba4fe0bbff13..f2754793b0e53a18ee806667823717abe6d69d8e 100644 (file)
@@ -397,7 +397,7 @@ class SelectableClassType(type):
     def update(cls, whereclause=None, values=None, **kwargs):
         _ddl_error(cls)
 
-    def __selectable__(cls):
+    def __clause_element__(cls):
         return cls._table
 
     def __getattr__(cls, attr):
@@ -442,7 +442,7 @@ def _selectable_name(selectable):
         return x
 
 def class_for_table(selectable, **mapper_kwargs):
-    selectable = expression._selectable(selectable)
+    selectable = expression._clause_element_as_expr(selectable)
     mapname = 'Mapped' + _selectable_name(selectable)
     if isinstance(mapname, unicode): 
         engine_encoding = selectable.metadata.bind.dialect.encoding 
@@ -531,7 +531,7 @@ class SqlSoup:
 
     def with_labels(self, item):
         # TODO give meaningful aliases
-        return self.map(expression._selectable(item).select(use_labels=True).alias('foo'))
+        return self.map(expression._clause_element_as_expr(item).select(use_labels=True).alias('foo'))
 
     def join(self, *args, **kwargs):
         j = join(*args, **kwargs)
index 446c55b41ecb107359a06b22f3a8f8e8da996969..e3901f9b1057b1f960a60fd710fece803599d869 100644 (file)
@@ -758,6 +758,7 @@ class CollectionAttributeImpl(AttributeImpl):
         state.commit([self.key])
 
         if self.key in state.pending:
+            
             # pending items exist.  issue a modified event,
             # add/remove new items.
             state.modified_event(self, True, user_data)
@@ -1027,6 +1028,7 @@ class InstanceState(object):
                 if impl.accepts_scalar_loader:
                     self.callables[key] = self
             self.dict.pop(key, None)
+            self.pending.pop(key, None)
             self.committed_state.pop(key, None)
 
     def reset(self, key):
@@ -1201,6 +1203,9 @@ class ClassManager(dict):
                 manager = create_manager_for_cls(cls)
             manager.instrument_attribute(key, inst, True)
 
+    def post_configure_attribute(self, key):
+        pass
+        
     def uninstrument_attribute(self, key, propagated=False):
         if key not in self:
             return
@@ -1354,6 +1359,9 @@ class _ClassInstrumentationAdapter(ClassManager):
         if not propagated:
             self._adapted.instrument_attribute(self.class_, key, inst)
 
+    def post_configure_attribute(self, key):
+        self._adapted.post_configure_attribute(self.class_, key, self[key])
+
     def install_descriptor(self, key, inst):
         self._adapted.install_descriptor(self.class_, key, inst)
 
@@ -1579,9 +1587,10 @@ def register_attribute_impl(class_, key, **kw):
             key, factory or list)
     else:
         typecallable = kw.pop('typecallable', None)
-
+        
     manager[key].impl = _create_prop(class_, key, manager, typecallable=typecallable, **kw)
-
+    manager.post_configure_attribute(key)
+    
 def register_descriptor(class_, key, proxy_property=None, comparator=None, parententity=None, property_=None):
     manager = manager_of_class(class_)
 
index 8628c8a239d84ac2d50bf34dd2cda301a72d0f64..e78891353130971f7e6d219d7f6f9f44446717aa 100644 (file)
@@ -34,8 +34,8 @@ class EvaluatorCompiler(object):
         return lambda obj: None
 
     def visit_column(self, clause):
-        if 'parententity' in clause._annotations:
-            key = clause._annotations['parententity']._get_col_to_prop(clause).key
+        if 'parentmapper' in clause._annotations:
+            key = clause._annotations['parentmapper']._get_col_to_prop(clause).key
         else:
             key = clause.key
         get_corresponding_attr = operator.attrgetter(key)
index 6c3c3b1baeb969791306222784b455ada4da9f01..3b7507def64cb96b0ec3c2253454eed606b00a72 100644 (file)
@@ -8,7 +8,7 @@
 
 Semi-private module containing various base classes used throughout the ORM.
 
-Defines the extension classes :class:`MapperExtension`, 
+Defines the extension classes :class:`MapperExtension`,
 :class:`SessionExtension`, and :class:`AttributeExtension` as
 well as other user-subclassable extension objects.
 
@@ -167,7 +167,7 @@ class MapperExtension(object):
         ``__new__``, and after initial attribute population has
         occurred.
 
-        This typicically occurs when the instance is created based on
+        This typically occurs when the instance is created based on
         incoming result rows, and is only called once for that
         instance's lifetime.
 
@@ -325,7 +325,7 @@ class SessionExtension(object):
         `query_context` was the query context object.
         `result` is the result object returned from the bulk operation.
         """
-    
+
 class MapperProperty(object):
     """Manage the relationship of a ``Mapper`` to a single class
     attribute, as well as that attribute as it appears on individual
@@ -394,36 +394,36 @@ class MapperProperty(object):
 
     def instrument_class(self, mapper):
         raise NotImplementedError()
-        
+
     _compile_started = False
     _compile_finished = False
-    
+
     def init(self):
         """Called after all mappers are created to assemble
         relationships between mappers and perform other post-mapper-creation
-        initialization steps.  
-        
+        initialization steps.
+
         """
         self._compile_started = True
         self.do_init()
         self._compile_finished = True
-        
+
     def do_init(self):
         """Perform subclass-specific initialization post-mapper-creation steps.
 
         This is a *template* method called by the
         ``MapperProperty`` object's init() method.
-        
+
         """
         pass
-    
+
     def post_instrument_class(self, mapper):
         """Perform instrumentation adjustments that need to occur
         after init() has completed.
-        
+
         """
         pass
-        
+
     def register_dependencies(self, *args, **kwargs):
         """Called by the ``Mapper`` in response to the UnitOfWork
         calling the ``Mapper``'s register_dependencies operation.
@@ -482,10 +482,10 @@ class PropComparator(expression.ColumnOperators):
     def adapted(self, adapter):
         """Return a copy of this PropComparator which will use the given adaption function
         on the local side of generated expressions.
-        
+
         """
         return self.__class__(self.prop, self.mapper, adapter)
-        
+
     @staticmethod
     def any_op(a, b, **kwargs):
         return a.any(b, **kwargs)
@@ -589,7 +589,7 @@ class StrategizedProperty(MapperProperty):
     def post_instrument_class(self, mapper):
         if self.is_primary():
             self.strategy.init_class_attribute(mapper)
-                
+
 def build_path(entity, key, prev=None):
     if prev:
         return prev + (entity, key)
@@ -738,35 +738,35 @@ class PropertyOption(MapperOption):
 
 class AttributeExtension(object):
     """An event handler for individual attribute change events.
-    
-    AttributeExtension is assembled within the descriptors associated 
-    with a mapped class. 
-    
+
+    AttributeExtension is assembled within the descriptors associated
+    with a mapped class.
+
     """
 
     def append(self, state, value, initiator):
         """Receive a collection append event.
-        
+
         The returned value will be used as the actual value to be
         appended.
-        
+
         """
         return value
 
     def remove(self, state, value, initiator):
         """Receive a remove event.
-        
+
         No return value is defined.
-        
+
         """
         pass
 
     def set(self, state, value, oldvalue, initiator):
         """Receive a set event.
-        
+
         The returned value will be used as the actual value to be
         set.
-        
+
         """
         return value
 
@@ -855,7 +855,12 @@ class LoaderStrategy(object):
             return fn
 
 class InstrumentationManager(object):
-    """User-defined class instrumentation extension."""
+    """User-defined class instrumentation extension.
+    
+    The API for this class should be considered as semi-stable,
+    and may change slightly with new releases.
+    
+    """
 
     # r4361 added a mandatory (cls) constructor to this interface.
     # given that, perhaps class_ should be dropped from all of these
@@ -878,6 +883,9 @@ class InstrumentationManager(object):
     def instrument_attribute(self, class_, key, inst):
         pass
 
+    def post_configure_attribute(self, class_, key, inst):
+        pass
+
     def install_descriptor(self, class_, key, inst):
         setattr(class_, key, inst)
 
index f05613f5c0bad899eb75cbb294030757e37ebdd8..2a772dcac244a7d4a2b78da93b05fb8f21fdbea7 100644 (file)
@@ -121,7 +121,7 @@ class ColumnProperty(StrategizedProperty):
             if self.adapter:
                 return self.adapter(self.prop.columns[0])
             else:
-                return self.prop.columns[0]._annotate({"parententity": self.mapper})
+                return self.prop.columns[0]._annotate({"parententity": self.mapper, "parentmapper":self.mapper})
                 
         def operate(self, op, *other, **kwargs):
             return op(self.__clause_element__(), *other, **kwargs)
@@ -417,7 +417,7 @@ class RelationProperty(StrategizedProperty):
             if backref:
                 raise sa_exc.ArgumentError("backref and back_populates keyword arguments are mutually exclusive")
             self.backref = None
-        elif isinstance(backref, str):
+        elif isinstance(backref, basestring):
             # propagate explicitly sent primary/secondary join conditions to the BackRef object if
             # just a string was sent
             if secondary is not None:
@@ -485,11 +485,11 @@ class RelationProperty(StrategizedProperty):
                 if self.property.direction in [ONETOMANY, MANYTOMANY]:
                     return ~self._criterion_exists()
                 else:
-                    return self.property._optimized_compare(None, adapt_source=self.adapter)
+                    return _orm_annotate(self.property._optimized_compare(None, adapt_source=self.adapter))
             elif self.property.uselist:
                 raise sa_exc.InvalidRequestError("Can't compare a collection to an object or collection; use contains() to test for membership.")
             else:
-                return self.property._optimized_compare(other, adapt_source=self.adapter)
+                return _orm_annotate(self.property._optimized_compare(other, adapt_source=self.adapter))
 
         def _criterion_exists(self, criterion=None, **kwargs):
             if getattr(self, '_of_type', None):
@@ -889,29 +889,45 @@ class RelationProperty(StrategizedProperty):
                 self.direction = MANYTOONE
 
         else:
-            for mappedtable, parenttable in [(self.mapper.mapped_table, self.parent.mapped_table), (self.mapper.local_table, self.parent.local_table)]:
-                onetomany = [c for c in self._foreign_keys if mappedtable.c.contains_column(c)]
-                manytoone = [c for c in self._foreign_keys if parenttable.c.contains_column(c)]
-
-                if not onetomany and not manytoone:
-                    raise sa_exc.ArgumentError(
-                        "Can't determine relation direction for relationship '%s' "
-                        "- foreign key columns are present in neither the "
-                        "parent nor the child's mapped tables" %(str(self)))
-                elif onetomany and manytoone:
-                    continue
-                elif onetomany:
+            foreign_keys = [f for c, f in self.synchronize_pairs]
+
+            parentcols = util.column_set(self.parent.mapped_table.c)
+            targetcols = util.column_set(self.mapper.mapped_table.c)
+
+            # fk collection which suggests ONETOMANY.
+            onetomany_fk = targetcols.intersection(foreign_keys)
+
+            # fk collection which suggests MANYTOONE.
+            manytoone_fk = parentcols.intersection(foreign_keys)
+            
+            if not onetomany_fk and not manytoone_fk:
+                raise sa_exc.ArgumentError(
+                    "Can't determine relation direction for relationship '%s' "
+                    "- foreign key columns are present in neither the "
+                    "parent nor the child's mapped tables" % self )
+
+            elif onetomany_fk and manytoone_fk: 
+                # fks on both sides.  do the same
+                # test only based on the local side.
+                referents = [c for c, f in self.synchronize_pairs]
+                onetomany_local = parentcols.intersection(referents)
+                manytoone_local = targetcols.intersection(referents)
+
+                if onetomany_local and not manytoone_local:
                     self.direction = ONETOMANY
-                    break
-                elif manytoone:
+                elif manytoone_local and not onetomany_local:
                     self.direction = MANYTOONE
-                    break
-            else:
+            elif onetomany_fk:
+                self.direction = ONETOMANY
+            elif manytoone_fk:
+                self.direction = MANYTOONE
+                
+            if not self.direction:
                 raise sa_exc.ArgumentError(
                     "Can't determine relation direction for relationship '%s' "
                     "- foreign key columns are present in both the parent and "
                     "the child's mapped tables.  Specify 'foreign_keys' "
-                    "argument." % (str(self)))
+                    "argument." % self)
         
         if self.cascade.delete_orphan and not self.single_parent and \
             (self.direction is MANYTOMANY or self.direction is MANYTOONE):
@@ -1001,7 +1017,7 @@ class RelationProperty(StrategizedProperty):
         
 
     def _refers_to_parent_table(self):
-        return self.parent.mapped_table is self.target or self.parent.mapped_table is self.target
+        return self.parent.mapped_table is self.target
 
     def _is_self_referential(self):
         return self.mapper.common_parent(self.parent)
index de7f66882ba7b66cdfdebd9a56243b2e45df2cc3..db9ce1d67612a42076eef3ae4104ba3eef14a6a4 100644 (file)
@@ -43,7 +43,7 @@ aliased = AliasedClass
 
 def _generative(*assertions):
     """Mark a method as generative."""
-    
+
     @util.decorator
     def generate(fn, *args, **kw):
         self = args[0]._clone()
@@ -127,6 +127,7 @@ class Query(object):
 
     def __mapper_loads_polymorphically_with(self, mapper, adapter):
         for m2 in mapper._with_polymorphic_mappers:
+            self._polymorphic_adapters[m2] = adapter
             for m in m2.iterate_to_root():
                 self._polymorphic_adapters[m.mapped_table] = self._polymorphic_adapters[m.local_table] = adapter
 
@@ -139,12 +140,13 @@ class Query(object):
 
         if isinstance(from_obj, expression.Alias):
             self._from_obj_alias = sql_util.ColumnAdapter(self._from_obj, equivs)
-            
+
     def _get_polymorphic_adapter(self, entity, selectable):
         self.__mapper_loads_polymorphically_with(entity.mapper, sql_util.ColumnAdapter(selectable, entity.mapper._equivalent_columns))
 
     def _reset_polymorphic_adapter(self, mapper):
         for m2 in mapper._with_polymorphic_mappers:
+            self._polymorphic_adapters.pop(m2, None)
             for m in m2.iterate_to_root():
                 self._polymorphic_adapters.pop(m.mapped_table, None)
                 self._polymorphic_adapters.pop(m.local_table, None)
@@ -282,7 +284,7 @@ class Query(object):
         if self._order_by:
             raise sa_exc.InvalidRequestError("Query.%s() being called on a Query with existing criterion. " % meth)
         self.__no_criterion_condition(meth)
-    
+
     def __no_statement_condition(self, meth):
         if self._statement:
             raise sa_exc.InvalidRequestError(
@@ -317,37 +319,35 @@ class Query(object):
     @property
     def statement(self):
         """The full SELECT statement represented by this Query."""
-        
-        return self._compile_context(labels=self._with_labels).statement._annotate({'_halt_adapt': True})
 
-    @property
-    def _nested_statement(self):
-        return self.with_labels().enable_eagerloads(False).statement.correlate(None)
+        return self._compile_context(labels=self._with_labels).statement._annotate({'_halt_adapt': True})
 
     def subquery(self):
         """return the full SELECT statement represented by this Query, embedded within an Alias.
-        
+
         Eager JOIN generation within the query is disabled.
-        
-        """
 
+        """
         return self.enable_eagerloads(False).statement.alias()
 
+    def __clause_element__(self):
+        return self.enable_eagerloads(False).statement
+
     @_generative()
     def enable_eagerloads(self, value):
         """Control whether or not eager joins are rendered.
-        
-        When set to False, the returned Query will not render 
+
+        When set to False, the returned Query will not render
         eager joins regardless of eagerload() options
         or mapper-level lazy=False configurations.
-        
+
         This is used primarily when nesting the Query's
         statement into a subquery or other
         selectable.
-        
+
         """
         self._enable_eagerloads = value
-        
+
     @_generative()
     def with_labels(self):
         """Apply column labels to the return value of Query.statement.
@@ -410,7 +410,7 @@ class Query(object):
             attribute of the mapper will be used, if any.   This is useful
             for mappers that don't have polymorphic loading behavior by default,
             such as concrete table mappers.
-        
+
         """
         entity = self._generate_mapper_zero()
         entity.set_with_polymorphic(self, cls_or_mappers, selectable=selectable, discriminator=discriminator)
@@ -554,7 +554,7 @@ class Query(object):
         those being selected.
 
         """
-        fromclause = self._nested_statement
+        fromclause = self.with_labels().enable_eagerloads(False).statement.correlate(None)
         q = self._from_selectable(fromclause)
         if entities:
             q._set_entities(entities)
@@ -728,27 +728,27 @@ class Query(object):
             q2 = sess.query(SomeClass).filter(SomeClass.bar=='foo')
 
             q3 = q1.union(q2)
-            
+
         The method accepts multiple Query objects so as to control
         the level of nesting.  A series of ``union()`` calls such as::
-        
+
             x.union(y).union(z).all()
-            
+
         will nest on each ``union()``, and produces::
-        
+
             SELECT * FROM (SELECT * FROM (SELECT * FROM X UNION SELECT * FROM y) UNION SELECT * FROM Z)
-            
+
         Whereas::
-        
+
             x.union(y, z).all()
-            
+
         produces::
 
             SELECT * FROM (SELECT * FROM X UNION SELECT * FROM y UNION SELECT * FROM Z)
 
         """
         return self._from_selectable(
-                    expression.union(*([self._nested_statement]+ [x._nested_statement for x in q])))
+                    expression.union(*([self]+ list(q))))
 
     def union_all(self, *q):
         """Produce a UNION ALL of this Query against one or more queries.
@@ -758,7 +758,7 @@ class Query(object):
 
         """
         return self._from_selectable(
-                    expression.union_all(*([self._nested_statement]+ [x._nested_statement for x in q]))
+                    expression.union_all(*([self]+ list(q)))
                 )
 
     def intersect(self, *q):
@@ -769,7 +769,7 @@ class Query(object):
 
         """
         return self._from_selectable(
-                    expression.intersect(*([self._nested_statement]+ [x._nested_statement for x in q]))
+                    expression.intersect(*([self]+ list(q)))
                 )
 
     def intersect_all(self, *q):
@@ -780,7 +780,7 @@ class Query(object):
 
         """
         return self._from_selectable(
-                    expression.intersect_all(*([self._nested_statement]+ [x._nested_statement for x in q]))
+                    expression.intersect_all(*([self]+ list(q)))
                 )
 
     def except_(self, *q):
@@ -791,7 +791,7 @@ class Query(object):
 
         """
         return self._from_selectable(
-                    expression.except_(*([self._nested_statement]+ [x._nested_statement for x in q]))
+                    expression.except_(*([self]+ list(q)))
                 )
 
     def except_all(self, *q):
@@ -802,7 +802,7 @@ class Query(object):
 
         """
         return self._from_selectable(
-                    expression.except_all(*([self._nested_statement]+ [x._nested_statement for x in q]))
+                    expression.except_all(*([self]+ list(q)))
                 )
 
     @util.accepts_a_list_as_starargs(list_deprecation='pending')
@@ -887,7 +887,7 @@ class Query(object):
 
     @_generative(__no_statement_condition, __no_limit_offset)
     def __join(self, keys, outerjoin, create_aliases, from_joinpoint):
-        
+
         # copy collections that may mutate so they do not affect
         # the copied-from query.
         self.__currenttables = set(self.__currenttables)
@@ -904,7 +904,7 @@ class Query(object):
         # after the method completes,
         # the query's joinpoint will be set to this.
         right_entity = None
-        
+
         for arg1 in util.to_list(keys):
             aliased_entity = False
             alias_criterion = False
@@ -970,7 +970,7 @@ class Query(object):
                     if ent.corresponds_to(left_entity):
                         clause = ent.selectable
                         break
-                    
+
             if not clause:
                 if isinstance(onclause, interfaces.PropComparator):
                     clause = onclause.__clause_element__()
@@ -985,14 +985,14 @@ class Query(object):
                 onclause = prop
 
             # start looking at the right side of the join
-            
+
             mp, right_selectable, is_aliased_class = _entity_info(right_entity)
-            
+
             if mp is not None and right_mapper is not None and not mp.common_parent(right_mapper):
                 raise sa_exc.InvalidRequestError(
                     "Join target %s does not correspond to the right side of join condition %s" % (right_entity, onclause)
                 )
-            
+
             if not right_mapper and mp:
                 right_mapper = mp
 
@@ -1004,7 +1004,7 @@ class Query(object):
 
                     if not right_selectable.is_derived_from(right_mapper.mapped_table):
                         raise sa_exc.InvalidRequestError(
-                            "Selectable '%s' is not derived from '%s'" % 
+                            "Selectable '%s' is not derived from '%s'" %
                             (right_selectable.description, right_mapper.mapped_table.description))
 
                     if not isinstance(right_selectable, expression.Alias):
@@ -1026,7 +1026,7 @@ class Query(object):
                     # for joins across plain relation()s, try not to specify the
                     # same joins twice.  the __currenttables collection tracks
                     # what plain mapped tables we've joined to already.
-                    
+
                     if prop.table in self.__currenttables:
                         if prop.secondary is not None and prop.secondary not in self.__currenttables:
                             # TODO: this check is not strong enough for different paths to the same endpoint which
@@ -1039,7 +1039,7 @@ class Query(object):
                     if prop.secondary:
                         self.__currenttables.add(prop.secondary)
                     self.__currenttables.add(prop.table)
-                    
+
                     if of_type:
                         right_entity = of_type
                     else:
@@ -1057,8 +1057,8 @@ class Query(object):
                     onclause = right_adapter.traverse(onclause)
                 onclause = self._adapt_clause(onclause, False, True)
 
-            # determine if we want _ORMJoin to alias the onclause 
-            # to the given left side.  This is used if we're joining against a 
+            # determine if we want _ORMJoin to alias the onclause
+            # to the given left side.  This is used if we're joining against a
             # select_from() selectable, from_self() call, or the onclause
             # has been resolved into a MapperProperty.  Otherwise we assume
             # the onclause itself contains more specific information on how to
@@ -1066,10 +1066,10 @@ class Query(object):
             join_to_left = not is_aliased_class or \
                             onclause is prop or \
                             clause is self._from_obj and self._from_obj_alias
-            
-            # create the join                
+
+            # create the join
             clause = orm_join(clause, right_entity, onclause, isouter=outerjoin, join_to_left=join_to_left)
-            
+
             # set up state for the query as a whole
             if alias_criterion:
                 # adapt filter() calls based on our right side adaptation
@@ -1080,14 +1080,14 @@ class Query(object):
                 # and adapt when it renders columns and fetches them from results
                 if aliased_entity:
                     self.__mapper_loads_polymorphically_with(
-                                        right_mapper, 
+                                        right_mapper,
                                         ORMAdapter(right_entity, equivalents=right_mapper._equivalent_columns)
                                     )
-        
-        # loop finished.  we're selecting from 
+
+        # loop finished.  we're selecting from
         # our final clause now
         self._from_obj = clause
-        
+
         # future joins with from_joinpoint=True join from our established right_entity.
         self._joinpoint = right_entity
 
@@ -1126,13 +1126,13 @@ class Query(object):
 
             if isinstance(stop, int) and isinstance(start, int) and stop - start <= 0:
                 return []
-            
+
             # perhaps we should execute a count() here so that we
             # can still use LIMIT/OFFSET ?
             elif (isinstance(start, int) and start < 0) \
                 or (isinstance(stop, int) and stop < 0):
                 return list(self)[item]
-                
+
             res = self.slice(start, stop)
             if step is not None:
                 return list(res)[None:None:item.step]
@@ -1204,7 +1204,7 @@ class Query(object):
 
         if not isinstance(statement, (expression._TextClause, expression._SelectBaseMixin)):
             raise sa_exc.ArgumentError("from_statement accepts text(), select(), and union() objects only.")
-        
+
         self._statement = statement
 
     def first(self):
@@ -1439,18 +1439,18 @@ class Query(object):
 
     def count(self):
         """Apply this query's criterion to a SELECT COUNT statement.
-        
+
         If column expressions or LIMIT/OFFSET/DISTINCT are present,
-        the query "SELECT count(1) FROM (SELECT ...)" is issued, 
+        the query "SELECT count(1) FROM (SELECT ...)" is issued,
         so that the result matches the total number of rows
         this query would return.  For mapped entities,
-        the primary key columns of each is written to the 
+        the primary key columns of each is written to the
         columns clause of the nested SELECT statement.
-        
+
         For a Query which is only against mapped entities,
-        a simpler "SELECT count(1) FROM table1, table2, ... 
-        WHERE criterion" is issued.  
-        
+        a simpler "SELECT count(1) FROM table1, table2, ...
+        WHERE criterion" is issued.
+
         """
         should_nest = [self._should_nest_selectable]
         def ent_cols(ent):
@@ -1459,8 +1459,8 @@ class Query(object):
             else:
                 should_nest[0] = True
                 return [ent.column]
-                
-        return self._col_aggregate(sql.literal_column('1'), sql.func.count, 
+
+        return self._col_aggregate(sql.literal_column('1'), sql.func.count,
             nested_cols=chain(*[ent_cols(ent) for ent in self._entities]),
             should_nest = should_nest[0]
         )
@@ -1498,9 +1498,9 @@ class Query(object):
     def delete(self, synchronize_session='evaluate'):
         """Perform a bulk delete query.
 
-        Deletes rows matched by this query from the database. 
-        
-        :param synchronize_session: chooses the strategy for the removal of matched 
+        Deletes rows matched by this query from the database.
+
+        :param synchronize_session: chooses the strategy for the removal of matched
             objects from the session. Valid values are:
 
             False
@@ -1528,10 +1528,10 @@ class Query(object):
         The method does *not* offer in-Python cascading of relations - it is assumed that
         ON DELETE CASCADE is configured for any foreign key references which require it.
         The Session needs to be expired (occurs automatically after commit(), or call expire_all())
-        in order for the state of dependent objects subject to delete or delete-orphan cascade to be 
+        in order for the state of dependent objects subject to delete or delete-orphan cascade to be
         correctly represented.
-        
-        Also, the ``before_delete()`` and ``after_delete()`` :class:`~sqlalchemy.orm.interfaces.MapperExtension` 
+
+        Also, the ``before_delete()`` and ``after_delete()`` :class:`~sqlalchemy.orm.interfaces.MapperExtension`
         methods are not called from this method.  For a delete hook here, use the
         ``after_bulk_delete()`` :class:`~sqlalchemy.orm.interfaces.MapperExtension` method.
 
@@ -1591,11 +1591,11 @@ class Query(object):
     def update(self, values, synchronize_session='evaluate'):
         """Perform a bulk update query.
 
-        Updates rows matched by this query in the database. 
-        
+        Updates rows matched by this query in the database.
+
         :param values: a dictionary with attributes names as keys and literal values or sql expressions
-            as values. 
-        
+            as values.
+
         :param synchronize_session: chooses the strategy to update the
             attributes on objects in the session. Valid values are:
 
@@ -1621,10 +1621,14 @@ class Query(object):
 
         The method does *not* offer in-Python cascading of relations - it is assumed that
         ON UPDATE CASCADE is configured for any foreign key references which require it.
-        
-        Also, the ``before_update()`` and ``after_update()`` :class:`~sqlalchemy.orm.interfaces.MapperExtension` 
+
+        The Session needs to be expired (occurs automatically after commit(), or call expire_all())
+        in order for the state of dependent objects subject foreign key cascade to be
+        correctly represented.
+
+        Also, the ``before_update()`` and ``after_update()`` :class:`~sqlalchemy.orm.interfaces.MapperExtension`
         methods are not called from this method.  For an update hook here, use the
-        ``after_bulk_update()`` :class:`~sqlalchemy.orm.interfaces.MapperExtension`  method.
+        ``after_bulk_update()`` :class:`~sqlalchemy.orm.interfaces.SessionExtension`  method.
 
         """
 
@@ -1694,7 +1698,7 @@ class Query(object):
 
         for ext in session.extensions:
             ext.after_bulk_update(session, self, context, result)
-            
+
         return result.rowcount
 
     def _compile_context(self, labels=True):
@@ -1894,10 +1898,7 @@ class _MapperEntity(_QueryEntity):
 
         adapter = None
         if not self.is_aliased_class and query._polymorphic_adapters:
-            for mapper in self.mapper.iterate_to_root():
-                adapter = query._polymorphic_adapters.get(mapper.mapped_table, None)
-                if adapter:
-                    break
+            adapter = query._polymorphic_adapters.get(self.mapper, None)
 
         if not adapter and self.adapter:
             adapter = self.adapter
@@ -1959,7 +1960,7 @@ class _MapperEntity(_QueryEntity):
             # apply adaptation to the mapper's order_by if needed.
             if adapter:
                 context.order_by = adapter.adapt_list(util.to_list(context.order_by))
-                    
+
         for value in self.mapper._iterate_polymorphic_properties(self._with_polymorphic):
             if query._only_load_props and value.key not in query._only_load_props:
                 continue
@@ -1971,14 +1972,14 @@ class _MapperEntity(_QueryEntity):
                 only_load_props=query._only_load_props,
                 column_collection=context.primary_columns
             )
-        
+
         if self._polymorphic_discriminator:
             if adapter:
                 pd = adapter.columns[self._polymorphic_discriminator]
             else:
                 pd = self._polymorphic_discriminator
             context.primary_columns.append(pd)
-            
+
     def __str__(self):
         return str(self.mapper)
 
@@ -1995,20 +1996,25 @@ class _ColumnEntity(_QueryEntity):
             column = column.__clause_element__()
         else:
             self._result_label = getattr(column, 'key', None)
-        
+
         if not isinstance(column, expression.ColumnElement) and hasattr(column, '_select_iterable'):
             for c in column._select_iterable:
                 if c is column:
                     break
                 _ColumnEntity(query, c)
-            
+
             if c is not column:
                 return
 
         if not isinstance(column, sql.ColumnElement):
             raise sa_exc.InvalidRequestError("Invalid column expression '%r'" % column)
 
-        if not hasattr(column, '_label'):
+        # if the Column is unnamed, give it a
+        # label() so that mutable column expressions
+        # can be located in the result even
+        # if the expression's identity has been changed
+        # due to adaption
+        if not column._label:
             column = column.label(None)
 
         query._entities.append(self)
index c756045a1e8b26b928ebe3bdc53010023fb2b499..61c58b2499a45f369bceae28c3d1ac5662b5a4b0 100644 (file)
@@ -267,7 +267,7 @@ class UOWTransaction(object):
         for elem in self.elements:
             if elem.isdelete:
                 self.session._remove_newly_deleted(elem.state)
-            else:
+            elif not elem.listonly:
                 self.session._register_newly_persistent(elem.state)
 
     def _sort_dependencies(self):
index c8637290177ef598bc320cf0a5b843f1f4017e76..1cca9e00b863602f2463ef308586e5a513262322 100644 (file)
@@ -258,13 +258,20 @@ class ORMAdapter(sql_util.ColumnAdapter):
     
     """
     def __init__(self, entity, equivalents=None, chain_to=None):
-        mapper, selectable, is_aliased_class = _entity_info(entity)
+        self.mapper, selectable, is_aliased_class = _entity_info(entity)
         if is_aliased_class:
             self.aliased_class = entity
         else:
             self.aliased_class = None
         sql_util.ColumnAdapter.__init__(self, selectable, equivalents, chain_to)
 
+    def replace(self, elem):
+        entity = elem._annotations.get('parentmapper', None)
+        if not entity or entity.isa(self.mapper):
+            return sql_util.ColumnAdapter.replace(self, elem)
+        else:
+            return None
+
 class AliasedClass(object):
     """Represents an 'alias'ed form of a mapped class for usage with Query.
     
@@ -303,7 +310,7 @@ class AliasedClass(object):
         self.__name__ = 'AliasedClass_' + str(self.__target)
         
     def __adapt_element(self, elem):
-        return self.__adapter.traverse(elem)._annotate({'parententity': self})
+        return self.__adapter.traverse(elem)._annotate({'parententity': self, 'parentmapper':self.__mapper})
         
     def __adapt_prop(self, prop):
         existing = getattr(self.__target, prop.key)
index 4608024bbfba6de7b784b363139ee441cecd9732..787827e8be063e6f7ff2dbd470bdc327c38e4b47 100644 (file)
@@ -277,8 +277,9 @@ class SQLCompiler(engine.Compiled):
             else:
                 schema_prefix = ''
             tablename = column.table.name
-            if isinstance(tablename, sql._generated_label):
-                tablename = tablename % self.anon_map
+            tablename = isinstance(tablename, sql._generated_label) and \
+                            self._truncated_identifier("alias", tablename) or tablename
+            
             return schema_prefix + self.preparer.quote(tablename, column.table.quote) + "." + name
 
     def escape_literal_column(self, text):
@@ -330,8 +331,16 @@ class SQLCompiler(engine.Compiled):
         return sep.join(s for s in (self.process(c) for c in clauselist.clauses)
                         if s is not None)
 
-    def visit_calculatedclause(self, clause, **kwargs):
-        return self.process(clause.clause_expr)
+    def visit_case(self, clause, **kwargs):
+        x = "CASE "
+        if clause.value:
+            x += self.process(clause.value) + " "
+        for cond, result in clause.whens:
+            x += "WHEN " + self.process(cond) + " THEN " + self.process(result) + " "
+        if clause.else_:
+            x += "ELSE " + self.process(clause.else_) + " "
+        x += "END"
+        return x
 
     def visit_cast(self, cast, **kwargs):
         return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))
@@ -444,8 +453,11 @@ class SQLCompiler(engine.Compiled):
 
     def visit_alias(self, alias, asfrom=False, **kwargs):
         if asfrom:
+            alias_name = isinstance(alias.name, sql._generated_label) and \
+                            self._truncated_identifier("alias", alias.name) or alias.name
+            
             return self.process(alias.original, asfrom=True, **kwargs) + " AS " + \
-                    self.preparer.format_alias(alias, alias.name % self.anon_map)
+                    self.preparer.format_alias(alias, alias_name)
         else:
             return self.process(alias.original, **kwargs)
 
index 34d50c9c591ad5a8acb5cc1a689b3f891417d908..cfc7f407eb72d5f1321a2f78026ebe825679418d 100644 (file)
@@ -461,26 +461,8 @@ def case(whens, value=None, else_=None):
           })
 
     """
-    try:
-        whens = util.dictlike_iteritems(whens)
-    except TypeError:
-        pass
-
-    if value:
-        crit_filter = _literal_as_binds
-    else:
-        crit_filter = _no_literals
-
-    whenlist = [ClauseList('WHEN', crit_filter(c), 'THEN', _literal_as_binds(r), operator=None)
-                for (c,r) in whens]
-    if else_ is not None:
-        whenlist.append(ClauseList('ELSE', _literal_as_binds(else_), operator=None))
-    if whenlist:
-        type = list(whenlist[-1])[-1].type
-    else:
-        type = None
-    cc = _CalculatedClause(None, 'CASE', value, type_=type, operator=None, group_contents=False, *whenlist + ['END'])
-    return cc
+    
+    return _Case(whens, value=value, else_=else_)
 
 def cast(clause, totype, **kwargs):
     """Return a ``CAST`` function.
@@ -508,9 +490,10 @@ def collate(expression, collation):
     """Return the clause ``expression COLLATE collation``."""
 
     expr = _literal_as_binds(expression)
-    return _CalculatedClause(
-        expr, expr, _literal_as_text(collation),
-        operator=operators.collate, group=False)
+    return _BinaryExpression(
+        expr, 
+        _literal_as_text(collation), 
+        operators.collate, type_=expr.type)
 
 def exists(*args, **kwargs):
     """Return an ``EXISTS`` clause as applied to a :class:`~sqlalchemy.sql.expression.Select` object.
@@ -922,6 +905,12 @@ def _literal_as_text(element):
     else:
         return element
 
+def _clause_element_as_expr(element):
+    if hasattr(element, '__clause_element__'):
+        return element.__clause_element__()
+    else:
+        return element
+        
 def _literal_as_column(element):
     if hasattr(element, '__clause_element__'):
         return element.__clause_element__()
@@ -958,14 +947,6 @@ def _corresponding_column_or_error(fromclause, column, require_embedded=False):
                 % (column, getattr(column, 'table', None), fromclause.description))
     return c
 
-def _selectable(element):
-    if hasattr(element, '__selectable__'):
-        return element.__selectable__()
-    elif isinstance(element, Selectable):
-        return element
-    else:
-        raise exc.ArgumentError("Object %r is not a Selectable and does not implement `__selectable__()`" % element)
-
 def is_column(col):
     """True if ``col`` is an instance of ``ColumnElement``."""
     return isinstance(col, ColumnElement)
@@ -1148,18 +1129,16 @@ class ClauseElement(Visitable):
 
         The return value is a :class:`~sqlalchemy.engine.Compiled` object.
         Calling `str()` or `unicode()` on the returned value will yield
-        a string representation of the result.   The ``Compiled``
+        a string representation of the result.   The :class:`~sqlalchemy.engine.Compiled`
         object also can return a dictionary of bind parameter names and
         values using the `params` accessor.
 
-        bind
-          An ``Engine`` or ``Connection`` from which a
+        :param bind: An ``Engine`` or ``Connection`` from which a
           ``Compiled`` will be acquired.  This argument
           takes precedence over this ``ClauseElement``'s
           bound engine, if any.
 
-        dialect
-          A ``Dialect`` instance frmo which a ``Compiled``
+        :param dialect: A ``Dialect`` instance frmo which a ``Compiled``
           will be acquired.  This argument takes precedence
           over the `bind` argument as well as this
           ``ClauseElement``'s bound engine, if any.
@@ -1433,6 +1412,8 @@ class _CompareMixin(ColumnOperators):
         return self._in_impl(operators.in_op, operators.notin_op, other)
 
     def _in_impl(self, op, negate_op, seq_or_selectable):
+        seq_or_selectable = _clause_element_as_expr(seq_or_selectable)
+            
         if isinstance(seq_or_selectable, _ScalarSelect):
             return self.__compare( op, seq_or_selectable, negate=negate_op)
 
@@ -1450,7 +1431,8 @@ class _CompareMixin(ColumnOperators):
         for o in seq_or_selectable:
             if not _is_literal(o):
                 if not isinstance( o, _CompareMixin):
-                    raise exc.InvalidRequestError( "in() function accepts either a list of non-selectable values, or a selectable: "+repr(o) )
+                    raise exc.InvalidRequestError( 
+                        "in() function accepts either a list of non-selectable values, or a selectable: %r" % o)
             else:
                 o = self._bind_param(o)
             args.append(o)
@@ -1534,9 +1516,7 @@ class _CompareMixin(ColumnOperators):
     def collate(self, collation):
         """Produce a COLLATE clause, i.e. ``<column> COLLATE utf8_bin``"""
 
-        return _CalculatedClause(
-           None, self, _literal_as_text(collation),
-            operator=operators.collate, group=False)
+        return collate(self, collation)
 
     def op(self, operator):
         """produce a generic operator function.
@@ -1607,7 +1587,8 @@ class ColumnElement(ClauseElement, _CompareMixin):
     primary_key = False
     foreign_keys = []
     quote = None
-
+    _label = None
+    
     @property
     def _select_iterable(self):
         return (self, )
@@ -1830,6 +1811,10 @@ class FromClause(Selectable):
         return ClauseAdapter(alias).traverse(self)
 
     def correspond_on_equivalents(self, column, equivalents):
+        """Return corresponding_column for the given column, or if None
+        search for a match in the given dictionary.
+        
+        """
         col = self.corresponding_column(column, require_embedded=True)
         if col is None and col in equivalents:
             for equiv in equivalents[col]:
@@ -1843,11 +1828,9 @@ class FromClause(Selectable):
         object from this ``Selectable`` which corresponds to that
         original ``Column`` via a common anscestor column.
 
-        column
-          the target ``ColumnElement`` to be matched
+        :param column: the target ``ColumnElement`` to be matched
 
-        require_embedded
-          only return corresponding columns for the given
+        :param require_embedded: only return corresponding columns for the given
           ``ColumnElement``, if the given ``ColumnElement`` is
           actually present within a sub-element of this
           ``FromClause``.  Normally the column will match if it merely
@@ -2216,73 +2199,55 @@ class BooleanClauseList(ClauseList, ColumnElement):
         return (self, )
 
 
-class _CalculatedClause(ColumnElement):
-    """Describe a calculated SQL expression that has a type, like ``CASE``.
-
-    Extends ``ColumnElement`` to provide column-level comparison
-    operators.
-
-    """
+class _Case(ColumnElement):
+    __visit_name__ = 'case'
 
-    __visit_name__ = 'calculatedclause'
+    def __init__(self, whens, value=None, else_=None):
+        try:
+            whens = util.dictlike_iteritems(whens)
+        except TypeError:
+            pass
 
-    def __init__(self, name, *clauses, **kwargs):
-        self.name = name
-        self.type = sqltypes.to_instance(kwargs.get('type_', None))
-        self._bind = kwargs.get('bind', None)
-        self.group = kwargs.pop('group', True)
-        clauses = ClauseList(
-            operator=kwargs.get('operator', None),
-            group_contents=kwargs.get('group_contents', True),
-            *clauses)
-        if self.group:
-            self.clause_expr = clauses.self_group()
+        if value:
+            whenlist = [(_literal_as_binds(c).self_group(), _literal_as_binds(r)) for (c, r) in whens]
         else:
-            self.clause_expr = clauses
-
-    @property
-    def key(self):
-        return self.name or '_calc_'
+            whenlist = [(_no_literals(c).self_group(), _literal_as_binds(r)) for (c, r) in whens]
+            
+        if whenlist:
+            type_ = list(whenlist[-1])[-1].type
+        else:
+            type_ = None
+            
+        self.value = value
+        self.type = type_
+        self.whens = whenlist
+        if else_ is not None:
+            self.else_ = _literal_as_binds(else_)
+        else:
+            self.else_ = None
 
     def _copy_internals(self, clone=_clone):
-        self.clause_expr = clone(self.clause_expr)
-
-    @property
-    def clauses(self):
-        if isinstance(self.clause_expr, _Grouping):
-            return self.clause_expr.element
-        else:
-            return self.clause_expr
+        if self.value:
+            self.value = clone(self.value)
+        self.whens = [(clone(x), clone(y)) for x, y in self.whens]
+        if self.else_:
+            self.else_ = clone(self.else_)
 
     def get_children(self, **kwargs):
-        return self.clause_expr,
+        if self.value:
+            yield self.value
+        for x, y in self.whens:
+            yield x
+            yield y
+        if self.else_:
+            yield self.else_ 
 
     @property
     def _from_objects(self):
-        return self.clauses._from_objects
-
-    def _bind_param(self, obj):
-        return _BindParamClause(self.name, obj, type_=self.type, unique=True)
+        return itertools.chain(*[x._from_objects for x in self.get_children()])
 
-    def select(self):
-        return select([self])
-
-    def scalar(self):
-        return select([self]).execute().scalar()
-
-    def execute(self):
-        return select([self]).execute()
-
-    def _compare_type(self, obj):
-        return self.type
-
-class Function(_CalculatedClause, FromClause):
-    """Describe a SQL function.
-
-    Extends ``_CalculatedClause``, turn the *clauselist* into function
-    arguments, also adds a `packagenames` argument.
-
-    """
+class Function(ColumnElement, FromClause):
+    """Describe a SQL function."""
 
     __visit_name__ = 'function'
 
@@ -2302,12 +2267,36 @@ class Function(_CalculatedClause, FromClause):
     def columns(self):
         return [self]
 
+    @util.memoized_property
+    def clauses(self):
+        return self.clause_expr.element
+        
+    @property
+    def _from_objects(self):
+        return self.clauses._from_objects
+
+    def get_children(self, **kwargs):
+        return self.clause_expr, 
+
     def _copy_internals(self, clone=_clone):
-        _CalculatedClause._copy_internals(self, clone=clone)
+        self.clause_expr = clone(self.clause_expr)
         self._reset_exported()
+        util.reset_memoized(self, 'clauses')
+        
+    def _bind_param(self, obj):
+        return _BindParamClause(self.name, obj, type_=self.type, unique=True)
 
-    def get_children(self, **kwargs):
-        return _CalculatedClause.get_children(self, **kwargs)
+    def select(self):
+        return select([self])
+
+    def scalar(self):
+        return select([self]).execute().scalar()
+
+    def execute(self):
+        return select([self]).execute()
+
+    def _compare_type(self, obj):
+        return self.type
 
 
 class _Cast(ColumnElement):
@@ -2493,8 +2482,8 @@ class Join(FromClause):
     __visit_name__ = 'join'
 
     def __init__(self, left, right, onclause=None, isouter=False):
-        self.left = _selectable(left)
-        self.right = _selectable(right).self_group()
+        self.left = _literal_as_text(left)
+        self.right = _literal_as_text(right).self_group()
 
         if onclause is None:
             self.onclause = self._match_primaries(self.left, self.right)
@@ -2843,9 +2832,12 @@ class ColumnClause(_Immutable, ColumnElement):
 
         elif self.table and self.table.named_with_column:
             if getattr(self.table, 'schema', None):
-                label = self.table.schema + "_" + _escape_for_generated(self.table.name) + "_" + _escape_for_generated(self.name)
+                label = self.table.schema + "_" + \
+                            _escape_for_generated(self.table.name) + "_" + \
+                            _escape_for_generated(self.name)
             else:
-                label = _escape_for_generated(self.table.name) + "_" + _escape_for_generated(self.name)
+                label = _escape_for_generated(self.table.name) + "_" + \
+                            _escape_for_generated(self.name)
 
             if label in self.table.c:
                 # TODO: coverage does not seem to be present for this
@@ -3133,6 +3125,8 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
 
         # some DBs do not like ORDER BY in the inner queries of a UNION, etc.
         for n, s in enumerate(selects):
+            s = _clause_element_as_expr(s)
+            
             if not numcols:
                 numcols = len(s.c)
             elif len(s.c) != numcols:
@@ -3398,9 +3392,7 @@ class Select(_SelectBaseMixin, FromClause):
         """return a new select() construct with the given FROM expression applied to its list of
         FROM objects."""
 
-        if _is_literal(fromclause):
-            fromclause = _TextClause(fromclause)
-
+        fromclause = _literal_as_text(fromclause)
         self._froms = self._froms.union([fromclause])
 
     @_generative
index 1bcc6d864f9b03bad9806a0d0330020e3fb0b838..c6cb938d44c4a6f473f6c50fa7c2e4aa8fbc0796 100644 (file)
@@ -13,18 +13,13 @@ class _GenericMeta(VisitableType):
 class GenericFunction(Function):
     __metaclass__ = _GenericMeta
 
-    def __init__(self, type_=None, group=True, args=(), **kwargs):
+    def __init__(self, type_=None, args=(), **kwargs):
         self.packagenames = []
         self.name = self.__class__.__name__
         self._bind = kwargs.get('bind', None)
-        if group:
-            self.clause_expr = ClauseList(
+        self.clause_expr = ClauseList(
                 operator=operators.comma_op,
                 group_contents=True, *args).self_group()
-        else:
-            self.clause_expr = ClauseList(
-                operator=operators.comma_op,
-                group_contents=True, *args)
         self.type = sqltypes.to_instance(
             type_ or getattr(self, '__return_type__', None))
 
index cd1e48cafed939f38278c245d1e0361692cf1446..879f0f3e517c0ecb2970e4684fef8d56a5df9979 100644 (file)
@@ -122,7 +122,7 @@ _PRECEDENCE = {
     and_: 3,
     or_: 2,
     comma_op: -1,
-    collate: -2,
+    collate: 7,
     as_: -1,
     exists: 0,
     _smallest: -1000,
index e1619cbc072c5ed911ca3a24c1e3dd820988e533..aeafb76475a11362d8e4f97b980bd6658e3c266b 100644 (file)
@@ -1383,10 +1383,7 @@ class memoized_instancemethod(object):
         return oneshot
 
 def reset_memoized(instance, name):
-    try:
-        del instance.__dict__[name]
-    except KeyError:
-        pass
+    instance.__dict__.pop(name, None)
 
 class WeakIdentityMapping(weakref.WeakKeyDictionary):
     """A WeakKeyDictionary with an object identity index.
index ace572fd541061c68df91ca3e0102688ba433254..3ce8a9220bcb505c7bbfc68bda2b7720c48cb3bf 100755 (executable)
@@ -126,6 +126,61 @@ class CompileTest(TestBase, AssertsCompiledSQL):
         self.assert_compile(func.current_date(), "GETDATE()")
         self.assert_compile(func.length(3), "LEN(:length_1)")
 
+
+class IdentityInsertTest(TestBase, AssertsCompiledSQL):
+    __only_on__ = 'mssql'
+    __dialect__ = mssql.MSSQLDialect()
+
+    def setUpAll(self):
+        global metadata, cattable
+        metadata = MetaData(testing.db)
+
+        cattable = Table('cattable', metadata,
+            Column('id', Integer),
+            Column('description', String(50)),
+            PrimaryKeyConstraint('id', name='PK_cattable'),
+        )
+
+    def setUp(self):
+        metadata.create_all()
+
+    def tearDown(self):
+        metadata.drop_all()
+
+    def test_compiled(self):
+        self.assert_compile(cattable.insert().values(id=9, description='Python'), "INSERT INTO cattable (id, description) VALUES (:id, :description)")
+
+    def test_execute(self):
+        cattable.insert().values(id=9, description='Python').execute()
+
+        cats = cattable.select().order_by(cattable.c.id).execute()
+        self.assertEqual([(9, 'Python')], list(cats))
+
+        result = cattable.insert().values(description='PHP').execute()
+        self.assertEqual([10], result.last_inserted_ids())
+        lastcat = cattable.select().order_by(desc(cattable.c.id)).execute()
+        self.assertEqual((10, 'PHP'), lastcat.fetchone())
+
+    def test_executemany(self):
+        cattable.insert().execute([
+            {'id': 89, 'description': 'Python'},
+            {'id': 8, 'description': 'Ruby'},
+            {'id': 3, 'description': 'Perl'},
+            {'id': 1, 'description': 'Java'},
+        ])
+
+        cats = cattable.select().order_by(cattable.c.id).execute()
+        self.assertEqual([(1, 'Java'), (3, 'Perl'), (8, 'Ruby'), (89, 'Python')], list(cats))
+
+        cattable.insert().execute([
+            {'description': 'PHP'},
+            {'description': 'Smalltalk'},
+        ])
+
+        lastcats = cattable.select().order_by(desc(cattable.c.id)).limit(2).execute()
+        self.assertEqual([(91, 'Smalltalk'), (90, 'PHP')], list(lastcats))
+
+
 class ReflectionTest(TestBase):
     __only_on__ = 'mssql'
 
index e62ef93eb302fe7f79bd72cf13880a62dbaabdfd..dfe2dfd18298b3db90324786db2bd4ccbd0c9d47 100644 (file)
@@ -639,15 +639,23 @@ class MiscTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
         m1 = MetaData(testing.db)
         t1 = Table('party', m1,
             Column('id', String(10), nullable=False),
-            Column('name', String(20), index=True)
+            Column('name', String(20), index=True), 
+            Column('aname', String(20))
             )
         m1.create_all()
+        
         testing.db.execute("""
           create index idx1 on party ((id || name))
         """, None) 
         testing.db.execute("""
           create unique index idx2 on party (id) where name = 'test'
         """, None)
+        
+        testing.db.execute("""
+            create index idx3 on party using btree
+                (lower(name::text), lower(aname::text))
+        """)
+        
         try:
             m2 = MetaData(testing.db)
 
@@ -663,6 +671,7 @@ class MiscTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
             # Make sure indexes are in the order we expect them in
             tmp = [(idx.name, idx) for idx in t2.indexes]
             tmp.sort()
+            
             r1, r2 = [idx[1] for idx in tmp]
 
             assert r1.name == 'idx2'
index 3345a5d8cf645455477d8675695864afdeb3fca0..746dc0e52f768614e76429624c1518b2044ae00f 100644 (file)
@@ -1176,5 +1176,116 @@ class CollectionAssignmentOrphanTest(_base.MappedTest):
         eq_(sess.query(A).get(a1.id),
             A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')]))
 
+
+class PartialFlushTest(_base.MappedTest):
+    """test cascade behavior as it relates to object lists passed to flush().
+    
+    """
+    def define_tables(self, metadata):
+        Table("base", metadata,
+            Column("id", Integer, primary_key=True),
+            Column("descr", String(50))
+        )
+
+        Table("noninh_child", metadata, 
+            Column('id', Integer, primary_key=True),
+            Column('base_id', Integer, ForeignKey('base.id'))
+        )
+
+        Table("parent", metadata,
+            Column("id", Integer, ForeignKey("base.id"), primary_key=True)
+        )
+        Table("inh_child", metadata,
+            Column("id", Integer, ForeignKey("base.id"), primary_key=True),
+            Column("parent_id", Integer, ForeignKey("parent.id"))
+        )
+
+
+    @testing.resolve_artifact_names
+    def test_o2m_m2o(self):
+        class Base(_base.ComparableEntity):
+            pass
+        class Child(_base.ComparableEntity):
+            pass
+
+        mapper(Base, base, properties={
+            'children':relation(Child, backref='parent')
+        })
+        mapper(Child, noninh_child)
+
+        sess = create_session()
+
+        c1, c2 = Child(), Child()
+        b1 = Base(descr='b1', children=[c1, c2])
+        sess.add(b1)
+
+        assert c1 in sess.new
+        assert c2 in sess.new
+        sess.flush([b1])
+
+        # c1, c2 get cascaded into the session on o2m.
+        # not sure if this is how I like this 
+        # to work but that's how it works for now.
+        assert c1 in sess and c1 not in sess.new
+        assert c2 in sess and c2 not in sess.new
+        assert b1 in sess and b1 not in sess.new
+
+        sess = create_session()
+        c1, c2 = Child(), Child()
+        b1 = Base(descr='b1', children=[c1, c2])
+        sess.add(b1)
+        sess.flush([c1])
+        # m2o, otoh, doesn't cascade up the other way.
+        assert c1 in sess and c1 not in sess.new
+        assert c2 in sess and c2 in sess.new
+        assert b1 in sess and b1 in sess.new
+
+        sess = create_session()
+        c1, c2 = Child(), Child()
+        b1 = Base(descr='b1', children=[c1, c2])
+        sess.add(b1)
+        sess.flush([c1, c2])
+        # m2o, otoh, doesn't cascade up the other way.
+        assert c1 in sess and c1 not in sess.new
+        assert c2 in sess and c2 not in sess.new
+        assert b1 in sess and b1 in sess.new
+
+    @testing.resolve_artifact_names
+    def test_circular_sort(self):
+        """test ticket 1306"""
+        
+        class Base(_base.ComparableEntity):
+            pass
+        class Parent(Base):
+            pass
+        class Child(Base):
+            pass
+
+        mapper(Base,base)
+
+        mapper(Child, inh_child,
+            inherits=Base,
+            properties={'parent': relation(
+                Parent,
+                backref='children', 
+                primaryjoin=inh_child.c.parent_id == parent.c.id
+            )}
+        )
+
+
+        mapper(Parent,parent, inherits=Base)
+
+        sess = create_session()
+        p1 = Parent()
+
+        c1, c2, c3 = Child(), Child(), Child()
+        p1.children = [c1, c2, c3]
+        sess.add(p1)
+        
+        sess.flush([c1])
+        assert p1 in sess.new
+        assert c1 not in sess.new
+        assert c2 in sess.new
+        
 if __name__ == "__main__":
     testenv.main()
index 4e8771347e209c403905b583e630e665f547cda7..c11fb69dfec3da1d5e595513b88859e956415bcb 100644 (file)
@@ -747,6 +747,51 @@ class PolymorphicExpireTest(_base.MappedTest):
         self.assert_sql_count(testing.db, go, 2)
         self.assertEquals(Engineer.name.get_history(e1), (['new engineer name'],(), ['engineer1']))
 
+class ExpiredPendingTest(_fixtures.FixtureTest):
+    run_define_tables = 'once'
+    run_setup_classes = 'once'
+    run_setup_mappers = None
+    run_inserts = None
+    
+    @testing.resolve_artifact_names
+    def test_expired_pending(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref='user'),
+            })
+        mapper(Address, addresses)
+
+        sess = create_session()
+        a1 = Address(email_address='a1')
+        sess.add(a1)
+        sess.flush()
+        
+        u1 = User(name='u1')
+        a1.user = u1
+        sess.flush()
+
+        # expire 'addresses'.  backrefs
+        # which attach to u1 will expect to be "pending"
+        sess.expire(u1, ['addresses'])
+
+        # attach an Address.  now its "pending" 
+        # in user.addresses
+        a2 = Address(email_address='a2')
+        a2.user = u1
+
+        # expire u1.addresses again.  this expires
+        # "pending" as well.
+        sess.expire(u1, ['addresses'])
+        
+        # insert a new row
+        sess.execute(addresses.insert(), dict(email_address='a3', user_id=u1.id))
+        
+        # only two addresses pulled from the DB, no "pending"
+        assert len(u1.addresses) == 2
+        
+        sess.flush()
+        sess.expire_all()
+        assert len(u1.addresses) == 3
+    
 
 class RefreshTest(_fixtures.FixtureTest):
 
@@ -783,9 +828,6 @@ class RefreshTest(_fixtures.FixtureTest):
         s.expire(u)
 
         # get the attribute, it refreshes
-        print "OK------"
-#        print u.__dict__
-#        print u._state.callables
         assert u.name == 'jack'
         assert id(a) not in [id(x) for x in u.addresses]
 
index fe948931b6acfb12998906a00bf1e4eedcf45ebe..ca789f8338ea23237bfc595833bca29a64f201aa 100644 (file)
@@ -9,6 +9,8 @@ from sqlalchemy.orm import *
 from sqlalchemy import exc as sa_exc
 from testlib import *
 from testlib import fixtures
+from orm import _base
+from testlib.testing import eq_
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.engine import default
 
@@ -748,7 +750,7 @@ class SelfReferentialTestJoinedToBase(ORMTest):
             sess.query(Engineer).join('reports_to', aliased=True).filter(Person.name=='dogbert').first(), 
             Engineer(name='dilbert'))
 
-class SelfReferentialTestJoinedToJoined(ORMTest):
+class SelfReferentialJ2JTest(ORMTest):
     keep_mappers = True
 
     def define_tables(self, metadata):
@@ -773,7 +775,7 @@ class SelfReferentialTestJoinedToJoined(ORMTest):
         
         mapper(Engineer, engineers, inherits=Person, 
           polymorphic_identity='engineer', properties={
-          'reports_to':relation(Manager, primaryjoin=managers.c.person_id==engineers.c.reports_to_id)
+          'reports_to':relation(Manager, primaryjoin=managers.c.person_id==engineers.c.reports_to_id, backref='engineers')
         })
 
     def test_has(self):
@@ -800,6 +802,62 @@ class SelfReferentialTestJoinedToJoined(ORMTest):
         self.assertEquals(
             sess.query(Engineer).join('reports_to', aliased=True).filter(Manager.name=='dogbert').first(), 
             Engineer(name='dilbert'))
+    
+    def test_filter_aliasing(self):
+        m1 = Manager(name='dogbert')
+        m2 = Manager(name='foo')
+        e1 = Engineer(name='wally', primary_language='java', reports_to=m1)
+        e2 = Engineer(name='dilbert', primary_language='c++', reports_to=m2)
+        e3 = Engineer(name='etc', primary_language='c++')
+        sess = create_session()
+        sess.add_all([m1, m2, e1, e2, e3])
+        sess.flush()
+        sess.expunge_all()
+
+        # filter aliasing applied to Engineer doesn't whack Manager
+        self.assertEquals(
+            sess.query(Manager).join(Manager.engineers).filter(Manager.name=='dogbert').all(),
+            [m1]
+        )
+
+        self.assertEquals(
+            sess.query(Manager).join(Manager.engineers).filter(Engineer.name=='dilbert').all(),
+            [m2]
+        )
+
+        self.assertEquals(
+            sess.query(Manager, Engineer).join(Manager.engineers).order_by(Manager.name.desc()).all(),
+            [
+                (m2, e2),
+                (m1, e1),
+            ]
+        )
+        
+    def test_relation_compare(self):
+        m1 = Manager(name='dogbert')
+        m2 = Manager(name='foo')
+        e1 = Engineer(name='dilbert', primary_language='java', reports_to=m1)
+        e2 = Engineer(name='wally', primary_language='c++', reports_to=m2)
+        e3 = Engineer(name='etc', primary_language='c++')
+        sess = create_session()
+        sess.add(m1)
+        sess.add(m2)
+        sess.add(e1)
+        sess.add(e2)
+        sess.add(e3)
+        sess.flush()
+        sess.expunge_all()
+
+        self.assertEquals(
+            sess.query(Manager).join(Manager.engineers).filter(Engineer.reports_to==None).all(), 
+            []
+        )
+
+        self.assertEquals(
+            sess.query(Manager).join(Manager.engineers).filter(Engineer.reports_to==m1).all(), 
+            [m1]
+        )
+
         
 
 class M2MFilterTest(ORMTest):
@@ -868,6 +926,8 @@ class M2MFilterTest(ORMTest):
         self.assertEquals(sess.query(Organization).filter(Organization.engineers.any(Engineer.name=='e1')).all(), [Organization(name='org1')])
 
 class SelfReferentialM2MTest(ORMTest, AssertsCompiledSQL):
+    keep_mappers = True
+    
     def define_tables(self, metadata):
         Base = declarative_base(metadata=metadata)
 
@@ -895,9 +955,50 @@ class SelfReferentialM2MTest(ORMTest, AssertsCompiledSQL):
         Child1.left_child2 = relation(Child2, secondary = secondary_table,
                primaryjoin = Parent.id == secondary_table.c.right_id,
                secondaryjoin = Parent.id == secondary_table.c.left_id,
-               uselist = False,
+               uselist = False, backref="right_children"
                                )
 
+    
+    def test_query_crit(self):
+        session = create_session()
+        c11, c12, c13 = Child1(), Child1(), Child1()
+        c21, c22, c23 = Child2(), Child2(), Child2()
+        
+        c11.left_child2 = c22
+        c12.left_child2 = c22
+        c13.left_child2 = c23
+        
+        session.add_all([c11, c12, c13, c21, c22, c23])
+        session.flush()
+        
+        # test that the join to Child2 doesn't alias Child1 in the select
+        eq_(
+            set(session.query(Child1).join(Child1.left_child2)), 
+            set([c11, c12, c13])
+        )
+
+        eq_(
+            set(session.query(Child1, Child2).join(Child1.left_child2)), 
+            set([(c11, c22), (c12, c22), (c13, c23)])
+        )
+
+        # test __eq__() on property is annotating correctly
+        eq_(
+            set(session.query(Child2).join(Child2.right_children).filter(Child1.left_child2==c22)),
+            set([c22])
+        )
+
+        # test the same again
+        self.assert_compile(
+            session.query(Child2).join(Child2.right_children).filter(Child1.left_child2==c22).with_labels().statement,
+            "SELECT parent.id AS parent_id, child2.id AS child2_id, parent.cls AS parent_cls FROM "
+            "secondary AS secondary_1, parent JOIN child2 ON parent.id = child2.id JOIN secondary AS secondary_2 "
+            "ON parent.id = secondary_2.left_id JOIN (SELECT parent.id AS parent_id, parent.cls AS parent_cls, "
+            "child1.id AS child1_id FROM parent JOIN child1 ON parent.id = child1.id) AS anon_1 ON "
+            "anon_1.parent_id = secondary_2.right_id WHERE anon_1.parent_id = secondary_1.right_id AND :param_1 = secondary_1.left_id",
+            dialect=default.DefaultDialect()
+        )
+
     def test_eager_join(self):
         session = create_session()
         
index 6f0b69f17a6df500dcf48f6f284363476d7fd993..a27dcfadc7a3f3d90c4e65fa705511cfa1ff9dac 100644 (file)
@@ -516,15 +516,57 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL):
         self.assert_compile(sess.query(x).filter(x==5).statement, 
             "SELECT lala(users.id) AS foo FROM users WHERE lala(users.id) = :param_1", dialect=default.DefaultDialect())
 
-class CompileTest(QueryTest):
+class ExpressionTest(QueryTest, AssertsCompiledSQL):
         
-    def test_deferred(self):
+    def test_deferred_instances(self):
         session = create_session()
         s = session.query(User).filter(and_(addresses.c.email_address == bindparam('emailad'), Address.user_id==User.id)).statement
 
         l = list(session.query(User).instances(s.execute(emailad = 'jack@bean.com')))
-        assert [User(id=7)] == l
+        eq_([User(id=7)], l)
 
+
+    def test_in(self):
+        session = create_session()
+        s = session.query(User.id).join(User.addresses).group_by(User.id).having(func.count(Address.id) > 2)
+        eq_(
+            session.query(User).filter(User.id.in_(s)).all(),
+            [User(id=8)]
+        )
+
+    def test_union(self):
+        s = create_session()
+        
+        q1 = s.query(User).filter(User.name=='ed').with_labels()
+        q2 = s.query(User).filter(User.name=='fred').with_labels()
+        eq_(
+            s.query(User).from_statement(union(q1, q2).order_by('users_name')).all(),
+            [User(name='ed'), User(name='fred')]
+        )
+    
+    def test_select(self):
+        s = create_session()
+        
+        # this is actually not legal on most DBs since the subquery has no alias
+        q1 = s.query(User).filter(User.name=='ed')
+        self.assert_compile(
+            select([q1]),
+            "SELECT id, name FROM (SELECT users.id AS id, users.name AS name FROM users WHERE users.name = :name_1)",
+            dialect=default.DefaultDialect()
+        )
+        
+    def test_join(self):
+        s = create_session()
+
+        # TODO: do we want aliased() to detect a query and convert to subquery() 
+        # automatically ?
+        q1 = s.query(Address).filter(Address.email_address=='jack@bean.com')
+        adalias = aliased(Address, q1.subquery())
+        eq_(
+            s.query(User, adalias).join((adalias, User.id==adalias.user_id)).all(),
+            [(User(id=7,name=u'jack'), Address(email_address=u'jack@bean.com',user_id=7,id=1))]
+        )
+        
 # more slice tests are available in test/orm/generative.py
 class SliceTest(QueryTest):
     def test_first(self):
index 532203ce2040d8defdbe221d5cdbba30c2c0c563..9787216f73274e4eb31460e020c47ab813ce97a6 100644 (file)
@@ -1,7 +1,7 @@
 import testenv; testenv.configure_for_tests()
 import datetime
 from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData
+from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData, and_
 from testlib.sa.orm import mapper, relation, backref, create_session, compile_mappers, clear_mappers
 from testlib.testing import eq_, startswith_
 from orm import _base, _fixtures
@@ -650,6 +650,79 @@ class RelationTest6(_base.MappedTest):
             [TagInstance(data='iplc_case'), TagInstance(data='not_iplc_case')]
         )
 
+class AmbiguousJoinInterpretedAsSelfRef(_base.MappedTest):
+    """test ambiguous joins due to FKs on both sides treated as self-referential.
+    
+    this mapping is very similar to that of test/orm/inheritance/query.py
+    SelfReferentialTestJoinedToBase , except that inheritance is not used
+    here.
+    
+    """
+    
+    def define_tables(self, metadata):
+        subscriber_table = Table('subscriber', metadata,
+           Column('id', Integer, primary_key=True),
+           Column('dummy', String(10)) # to appease older sqlite version
+          )
+
+        address_table = Table('address',
+                 metadata,
+                 Column('subscriber_id', Integer, ForeignKey('subscriber.id'), primary_key=True),
+                 Column('type', String(1), primary_key=True),
+                 )
+
+    @testing.resolve_artifact_names
+    def setup_mappers(self):
+        subscriber_and_address = subscriber.join(address, 
+               and_(address.c.subscriber_id==subscriber.c.id, address.c.type.in_(['A', 'B', 'C'])))
+
+        class Address(_base.ComparableEntity):
+            pass
+
+        class Subscriber(_base.ComparableEntity):
+            pass
+
+        mapper(Address, address)
+
+        mapper(Subscriber, subscriber_and_address, properties={
+           'id':[subscriber.c.id, address.c.subscriber_id],
+           'addresses' : relation(Address, 
+                backref=backref("customer"))
+           })
+        
+    @testing.resolve_artifact_names
+    def test_mapping(self):
+        from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE
+        sess = create_session()
+        assert Subscriber.addresses.property.direction is ONETOMANY
+        assert Address.customer.property.direction is MANYTOONE
+        
+        s1 = Subscriber(type='A',
+                addresses = [
+                    Address(type='D'),
+                    Address(type='E'),
+                ]
+        )
+        a1 = Address(type='B', customer=Subscriber(type='C'))
+        
+        assert s1.addresses[0].customer is s1
+        assert a1.customer.addresses[0] is a1
+        
+        sess.add_all([s1, a1])
+        
+        sess.flush()
+        sess.expunge_all()
+        
+        eq_(
+            sess.query(Subscriber).order_by(Subscriber.type).all(),
+            [
+                Subscriber(id=1, type=u'A'), 
+                Subscriber(id=2, type=u'B'), 
+                Subscriber(id=2, type=u'C')
+            ]
+        )
+
+
 class ManualBackrefTest(_fixtures.FixtureTest):
     """Test explicit relations that are backrefs to each other."""
 
index 3cb4dfb9fab6bfecda60fe841b22efc6ac088976..4f65f1d33cb4fcec6c60b5f23241043b253fc07f 100644 (file)
@@ -6,6 +6,8 @@ from sqlalchemy.orm.session import _sessions
 import operator
 from testlib import testing
 from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey, PickleType
+import sqlalchemy as sa
+from sqlalchemy.sql import column
 from orm import _base
 import sqlalchemy as sa
 from sqlalchemy.sql import column
index ac9b7e3292a831943acbbe3010752b02024e3678..1519575036901bcb97588f3d96316f739ec6d9d1 100644 (file)
@@ -37,7 +37,11 @@ class CompileTest(TestBase, AssertsCompiledSQL):
                     GenericFunction.__init__(self, args=[arg], **kwargs)
                 
             self.assert_compile(fake_func('foo'), "fake_func(%s)" % bindtemplate % {'name':'param_1', 'position':1}, dialect=dialect)
-    
+            
+    def test_use_labels(self):
+        self.assert_compile(select([func.foo()], use_labels=True), 
+            "SELECT foo() AS foo_1"
+        )
     def test_underscores(self):
         self.assert_compile(func.if_(), "if()")
         
index 2072fb75e87334a6444c8c6ad71eda0a719bb3ea..3947a450fe7dc04c5941a64f46772eaf354ce8cc 100644 (file)
@@ -447,6 +447,13 @@ class ClauseAdapterTest(TestBase, AssertsCompiledSQL):
         self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1)), "SELECT * FROM table2 WHERE t1alias.col1 = table2.col2")
         self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2)), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = table2.col2")
 
+        self.assert_compile(vis.traverse(case([(t1.c.col1==5, t1.c.col2)], else_=t1.c.col1)), 
+            "CASE WHEN (t1alias.col1 = :col1_1) THEN t1alias.col2 ELSE t1alias.col1 END"
+        )
+        self.assert_compile(vis.traverse(case([(5, t1.c.col2)], value=t1.c.col1, else_=t1.c.col1)), 
+            "CASE t1alias.col1 WHEN :param_1 THEN t1alias.col2 ELSE t1alias.col1 END"
+        )
+
 
         s = select(['*'], from_obj=[t1]).alias('foo')
         self.assert_compile(s.select(), "SELECT foo.* FROM (SELECT * FROM table1) AS foo")
index 5a620be8c89abe0e94e883723372559a8509b5eb..94ee20342e6ac7cb5d39c98c06a3337920c773ba 100644 (file)
@@ -4,9 +4,6 @@ from sqlalchemy import exc as exceptions
 from testlib import *
 from sqlalchemy.engine import default
 
-# TODO: either create a mock dialect with named paramstyle and a short identifier length,
-# or find a way to just use sqlite dialect and make those changes
-
 IDENT_LENGTH = 29
 
 class LabelTypeTest(TestBase):
@@ -20,13 +17,18 @@ class LabelTypeTest(TestBase):
 
 class LongLabelsTest(TestBase, AssertsCompiledSQL):
     def setUpAll(self):
-        global metadata, table1, maxlen
+        global metadata, table1, table2, maxlen
         metadata = MetaData(testing.db)
         table1 = Table("some_large_named_table", metadata,
             Column("this_is_the_primarykey_column", Integer, Sequence("this_is_some_large_seq"), primary_key=True),
             Column("this_is_the_data_column", String(30))
             )
 
+        table2 = Table("table_with_exactly_29_characs", metadata,
+            Column("this_is_the_primarykey_column", Integer, Sequence("some_seq"), primary_key=True),
+            Column("this_is_the_data_column", String(30))
+            )
+
         metadata.create_all()
 
         maxlen = testing.db.dialect.max_identifier_length
@@ -87,6 +89,37 @@ class LongLabelsTest(TestBase, AssertsCompiledSQL):
             (3, "data3"),
         ], repr(result)
 
+    def test_table_alias_names(self):
+        self.assert_compile(
+            table2.alias().select(),
+            "SELECT table_with_exactly_29_c_1.this_is_the_primarykey_column, table_with_exactly_29_c_1.this_is_the_data_column FROM table_with_exactly_29_characs AS table_with_exactly_29_c_1"
+        )
+
+        ta = table2.alias()
+        dialect = default.DefaultDialect()
+        dialect.max_identifier_length = IDENT_LENGTH
+        self.assert_compile(
+            select([table1, ta]).select_from(table1.join(ta, table1.c.this_is_the_data_column==ta.c.this_is_the_data_column)).\
+                        where(ta.c.this_is_the_data_column=='data3'),
+                        
+            "SELECT some_large_named_table.this_is_the_primarykey_column, some_large_named_table.this_is_the_data_column, "
+            "table_with_exactly_29_c_1.this_is_the_primarykey_column, table_with_exactly_29_c_1.this_is_the_data_column FROM "
+            "some_large_named_table JOIN table_with_exactly_29_characs AS table_with_exactly_29_c_1 ON "
+            "some_large_named_table.this_is_the_data_column = table_with_exactly_29_c_1.this_is_the_data_column "
+            "WHERE table_with_exactly_29_c_1.this_is_the_data_column = :this_is_the_data_column_1",
+            dialect=dialect
+        )
+        
+        table2.insert().execute(
+            {"this_is_the_primarykey_column":1, "this_is_the_data_column":"data1"},
+            {"this_is_the_primarykey_column":2, "this_is_the_data_column":"data2"},
+            {"this_is_the_primarykey_column":3, "this_is_the_data_column":"data3"},
+            {"this_is_the_primarykey_column":4, "this_is_the_data_column":"data4"},
+        )
+        
+        r = table2.alias().select().execute()
+        assert r.fetchall() == [(x, "data%d" % x) for x in range(1, 5)]
+        
     def test_colbinds(self):
         table1.insert().execute(**{"this_is_the_primarykey_column":1, "this_is_the_data_column":"data1"})
         table1.insert().execute(**{"this_is_the_primarykey_column":2, "this_is_the_data_column":"data2"})
@@ -153,9 +186,9 @@ class LongLabelsTest(TestBase, AssertsCompiledSQL):
             "FROM some_large_named_table WHERE some_large_named_table.this_is_the_primarykey_column = :this_1) AS anon_1", dialect=compile_dialect)
 
         compile_dialect = default.DefaultDialect(label_length=4)
-        self.assert_compile(x, "SELECT anon_1.this_is_the_primarykey_column AS _1, anon_1.this_is_the_data_column AS _2 FROM "
+        self.assert_compile(x, "SELECT _1.this_is_the_primarykey_column AS _1, _1.this_is_the_data_column AS _2 FROM "
             "(SELECT some_large_named_table.this_is_the_primarykey_column AS _3, some_large_named_table.this_is_the_data_column AS _4 "
-            "FROM some_large_named_table WHERE some_large_named_table.this_is_the_primarykey_column = :_1) AS anon_1", dialect=compile_dialect)
+            "FROM some_large_named_table WHERE some_large_named_table.this_is_the_primarykey_column = :_1) AS _1", dialect=compile_dialect)
         
         
 if __name__ == '__main__':
index 782016e7d681a49152106dca4b09998c22e4d0b6..a4de6e331eab0849e65da9c3a89a1d8c02088769 100644 (file)
@@ -131,6 +131,28 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
             select([ClauseList(column('a'), column('b'))]).select_from('sometable'), 
             'SELECT a, b FROM sometable'
         )
+        
+    def test_use_labels(self):
+        self.assert_compile(
+            select([table1.c.myid==5], use_labels=True),
+            "SELECT mytable.myid = :myid_1 AS anon_1 FROM mytable"
+        )
+
+        self.assert_compile(
+            select([func.foo()], use_labels=True),
+            "SELECT foo() AS foo_1"
+        )
+
+        self.assert_compile(
+            select([not_(True)], use_labels=True),
+            "SELECT NOT :param_1"       # TODO: should this make an anon label ??
+        )
+
+        self.assert_compile(
+            select([cast("data", Integer)], use_labels=True),      # this will work with plain Integer in 0.6
+            "SELECT CAST(:param_1 AS INTEGER) AS anon_1"
+        )
+        
     def test_nested_uselabels(self):
         """test nested anonymous label generation.  this
         essentially tests the ANONYMOUS_LABEL regex.
@@ -357,7 +379,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
             select([x.label('foo')]),
             'SELECT a AND b AND c AS foo'
         )
-        
+    
         self.assert_compile(
             and_(table1.c.myid == 12, table1.c.name=='asdf', table2.c.othername == 'foo', "sysdate() = today()"),
             "mytable.myid = :myid_1 AND mytable.name = :name_1 "\
@@ -812,20 +834,28 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today
         self.assert_compile(select([extract("day", func.to_date("03/20/2005", "MM/DD/YYYY"))]), "SELECT extract(day FROM to_date(:to_date_1, :to_date_2)) AS extract_1")
 
     def test_collate(self):
-        for expr in (select([table1.c.name.collate('somecol')]),
-                     select([collate(table1.c.name, 'somecol')])):
+        for expr in (select([table1.c.name.collate('latin1_german2_ci')]),
+                     select([collate(table1.c.name, 'latin1_german2_ci')])):
             self.assert_compile(
-                expr, "SELECT mytable.name COLLATE somecol FROM mytable")
+                expr, "SELECT mytable.name COLLATE latin1_german2_ci AS anon_1 FROM mytable")
 
-        expr = select([table1.c.name.collate('somecol').like('%x%')])
+        assert table1.c.name.collate('latin1_german2_ci').type is table1.c.name.type
+        
+        expr = select([table1.c.name.collate('latin1_german2_ci').label('k1')]).order_by('k1')
+        self.assert_compile(expr,"SELECT mytable.name COLLATE latin1_german2_ci AS k1 FROM mytable ORDER BY k1")
+
+        expr = select([collate('foo', 'latin1_german2_ci').label('k1')])
+        self.assert_compile(expr,"SELECT :param_1 COLLATE latin1_german2_ci AS k1")
+
+        expr = select([table1.c.name.collate('latin1_german2_ci').like('%x%')])
         self.assert_compile(expr,
-                            "SELECT mytable.name COLLATE somecol "
+                            "SELECT mytable.name COLLATE latin1_german2_ci "
                             "LIKE :param_1 AS anon_1 FROM mytable")
 
-        expr = select([table1.c.name.like(collate('%x%', 'somecol'))])
+        expr = select([table1.c.name.like(collate('%x%', 'latin1_german2_ci'))])
         self.assert_compile(expr,
                             "SELECT mytable.name "
-                            "LIKE :param_1 COLLATE somecol AS anon_1 "
+                            "LIKE :param_1 COLLATE latin1_german2_ci AS anon_1 "
                             "FROM mytable")
 
         expr = select([table1.c.name.collate('col1').like(
@@ -835,10 +865,14 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today
                             "LIKE :param_1 COLLATE col2 AS anon_1 "
                             "FROM mytable")
 
-        expr = select([func.concat('a', 'b').collate('somecol').label('x')])
+        expr = select([func.concat('a', 'b').collate('latin1_german2_ci').label('x')])
         self.assert_compile(expr,
                             "SELECT concat(:param_1, :param_2) "
-                            "COLLATE somecol AS x")
+                            "COLLATE latin1_german2_ci AS x")
+
+
+        expr = select([table1.c.name]).order_by(table1.c.name.collate('latin1_german2_ci'))
+        self.assert_compile(expr, "SELECT mytable.name FROM mytable ORDER BY mytable.name COLLATE latin1_german2_ci")
 
     def test_percent_chars(self):
         t = table("table%name",