]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- more query tests
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 10 Dec 2007 04:31:17 +0000 (04:31 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 10 Dec 2007 04:31:17 +0000 (04:31 +0000)
- trying to refine some of the adaptation stuff
- query.from_statement() wont allow further generative criterion
- added a warning to columncollection when selectable is formed with
conflicting columns (only in the col export phase)
- some method rearrangement on schema/columncollection....
- property conflicting relation warning doesnt raise for concrete

CHANGES
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/expression.py
test/orm/mapper.py
test/orm/query.py
test/sql/selectable.py

diff --git a/CHANGES b/CHANGES
index e0df8146943df0c48a48166e0b56ef11ff39f352..6154bc3d18e631c3760b4a21b3ce44cc4b42e0ef 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -36,7 +36,10 @@ CHANGES
       of the underlying type.  Ideal for using with Unicode or Pickletype.
       TypeDecorator should now be the primary way to augment the behavior of any
       existing type including other TypeDecorator subclasses such as PickleType.
-        
+
+    - selectables (and others) will issue a warning when two columns in
+      their exported columns collection conflict based on name.
+      
     - tables with schemas can still be used in sqlite, firebird,
       schema name just gets dropped [ticket:890]
 
index f0cae49d80643847903bf35cbd2c6c77808238ed..9394e9aeadd5b91613e90362c04870404c1ac4b4 100644 (file)
@@ -399,9 +399,10 @@ class PropertyLoader(StrategizedProperty):
         # ensure the "select_mapper", if different from the regular target mapper, is compiled.
         self.mapper.get_select_mapper()._check_compile()
 
-        for inheriting in self.parent.iterate_to_root():
-            if inheriting is not self.parent and inheriting.get_property(self.key, raiseerr=False):
-                warnings.warn(RuntimeWarning("Warning: relation '%s' on mapper '%s' supercedes the same relation on inherited mapper '%s'; this can cause dependency issues during flush" % (self.key, self.parent, inheriting)))
+        if not self.parent.concrete:
+            for inheriting in self.parent.iterate_to_root():
+                if inheriting is not self.parent and inheriting.get_property(self.key, raiseerr=False):
+                    warnings.warn(RuntimeWarning("Warning: relation '%s' on mapper '%s' supercedes the same relation on inherited mapper '%s'; this can cause dependency issues during flush" % (self.key, self.parent, inheriting)))
 
         if self.association is not None:
             if isinstance(self.association, type):
index dbc62a47b97246bf19cd7f441742a887a71ad68a..0dbfdc611e135c874ca80f633a3d888fa27a9d08 100644 (file)
@@ -49,6 +49,7 @@ class Query(object):
         self._autoflush = True
         self._eager_loaders = util.Set(chain(*[mp._eager_loaders for mp in [m for m in self.mapper.iterate_to_root()]]))
         self._attributes = {}
+        self.__joinable_tables = {}
         self._current_path = ()
         self._primary_adapter=None
         self._only_load_props = None
@@ -66,7 +67,13 @@ class Query(object):
         q._statement = q._aliases = q._criterion = None
         q._order_by = q._group_by = q._distinct = False
         return q
-        
+    
+    def _no_statement(self, meth):
+        q = self._clone()
+        if q._statement:
+            raise exceptions.InvalidRequestError("Query.%s() being called on a Query with an existing full statement - can't apply criterion." % meth)
+        return q
+    
     def _clone(self):
         q = Query.__new__(Query)
         q.__dict__ = self.__dict__.copy()
@@ -322,10 +329,10 @@ class Query(object):
         
         if self._aliases is not None:
             criterion = self._aliases.adapt_clause(criterion)
-        elif self._from_obj is not self.table:
+        elif self.table not in self._get_joinable_tables():
             criterion = sql_util.ClauseAdapter(self._from_obj).traverse(criterion)
             
-        q = self._clone()
+        q = self._no_statement("filter")
         if q._criterion is not None:
             q._criterion = q._criterion & criterion
         else:
@@ -341,12 +348,14 @@ class Query(object):
         return self.filter(sql.and_(*clauses))
 
     def _get_joinable_tables(self):
-        currenttables = [self._from_obj]
-        def visit_join(join):
-            currenttables.append(join.left)
-            currenttables.append(join.right)
-        visitors.traverse(self._from_obj, visit_join=visit_join, traverse_options={'column_collections':False, 'aliased_selectables':False})
-        return currenttables
+        if self._from_obj not in self.__joinable_tables:
+            currenttables = [self._from_obj]
+            def visit_join(join):
+                currenttables.append(join.left)
+                currenttables.append(join.right)
+            visitors.traverse(self._from_obj, visit_join=visit_join, traverse_options={'column_collections':False, 'aliased_selectables':False})
+            self.__joinable_tables = {self._from_obj : currenttables}
+        return self.__joinable_tables[self._from_obj]
         
     def _join_to(self, keys, outerjoin=False, start=None, create_aliases=True):
         if start is None:
@@ -416,7 +425,7 @@ class Query(object):
         """
         if self._column_aggregate is not None:
             raise exceptions.InvalidRequestError("Query already contains an aggregate column or function")
-        q = self._clone()
+        q = self._no_statement("aggregate")
         q._column_aggregate = (col, func)
         return q
 
@@ -484,7 +493,7 @@ class Query(object):
     def order_by(self, criterion):
         """apply one or more ORDER BY criterion to the query and return the newly resulting ``Query``"""
 
-        q = self._clone()
+        q = self._no_statement("order_by")
         if q._order_by is False:    
             q._order_by = util.to_list(criterion)
         else:
@@ -494,7 +503,7 @@ class Query(object):
     def group_by(self, criterion):
         """apply one or more GROUP BY criterion to the query and return the newly resulting ``Query``"""
 
-        q = self._clone()
+        q = self._no_statement("group_by")
         if q._group_by is False:    
             q._group_by = util.to_list(criterion)
         else:
@@ -514,7 +523,7 @@ class Query(object):
         if self._aliases is not None:
             criterion = self._aliases.adapt_clause(criterion)
             
-        q = self._clone()
+        q = self._no_statement("having")
         if q._having is not None:
             q._having = q._having & criterion
         else:
@@ -543,7 +552,7 @@ class Query(object):
 
     def _join(self, prop, id, outerjoin, aliased, from_joinpoint):
         (clause, mapper, aliases) = self._join_to(prop, outerjoin=outerjoin, start=from_joinpoint and self._joinpoint or self.mapper, create_aliases=aliased)
-        q = self._clone()
+        q = self._no_statement("join")
         q._from_obj = clause
         q._joinpoint = mapper
         q._aliases = aliases
@@ -568,7 +577,7 @@ class Query(object):
         the root.
         """
 
-        q = self._clone()
+        q = self._no_statement("reset_joinpoint")
         q._joinpoint = q.mapper
         q._aliases = None
         return q
@@ -638,7 +647,7 @@ class Query(object):
         ``Query``.
         """
 
-        new = self._clone()
+        new = self._no_statement("distinct")
         new._distinct = True
         return new
 
@@ -828,7 +837,7 @@ class Query(object):
             if lockmode is not None:
                 q = q.with_lockmode(lockmode)
             q = q._select_context_options(populate_existing=refresh_instance is not None, version_check=(lockmode is not None), only_load_props=only_load_props, refresh_instance=refresh_instance)
-            q = q.order_by(None)
+            q._order_by = None
             # call using all() to avoid LIMIT compilation complexity
             return q.all()[0]
         except IndexError:
@@ -891,23 +900,14 @@ class Query(object):
         whereclause = self._criterion
 
         from_obj = self._from_obj
-        currenttables = self._get_joinable_tables()
-        adapt_criterion = self.table not in currenttables
+        adapt_criterion = self.table not in self._get_joinable_tables()
 
-        if whereclause is not None and (self.mapper is not self.select_mapper):
-            # adapt the given WHERECLAUSE to adjust instances of this query's mapped 
-            # table to be that of our select_table,
-            # which may be the "polymorphic" selectable used by our mapper.
+        if not adapt_criterion and whereclause is not None and (self.mapper is not self.select_mapper):
             whereclause = sql_util.ClauseAdapter(from_obj).traverse(whereclause, stop_on=util.Set([from_obj]))
 
-            # if extra entities, adapt the criterion to those as well
-            for m in self._entities:
-                if isinstance(m, type):
-                    m = mapper.class_mapper(m)
-                if isinstance(m, mapper.Mapper):
-                    sql_util.ClauseAdapter(m.select_table).traverse(whereclause, stop_on=util.Set([m.select_table]))
+        # TODO: mappers added via add_entity(), adapt their queries also, 
+        # if those mappers are polymorphic
 
-        
         order_by = self._order_by
         if order_by is False:
             order_by = self.mapper.order_by
@@ -969,7 +969,7 @@ class Query(object):
             if adapt_criterion:
                 context.primary_columns = [from_obj.corresponding_column(c, raiseerr=False) or c for c in context.primary_columns]
                 cf = [from_obj.corresponding_column(c, raiseerr=False) or c for c in cf]
-                
+
             s2 = sql.select(context.primary_columns + list(cf), whereclause, from_obj=context.from_clause, use_labels=True, correlate=False, order_by=util.to_list(order_by), **self._select_args())
             
             s3 = s2.alias()
index 8179810033a33e3e8f43d185bb77c5350b2a6382..15b35b96a307f040b6aec6c7706bdcbb7a46af4d 100644 (file)
@@ -493,6 +493,7 @@ class Column(SchemaItem, expression._ColumnClause):
             [repr(self.name)] + [repr(self.type)] +
             [repr(x) for x in self.foreign_keys if x is not None] +
             [repr(x) for x in self.constraints] +
+            [(self.table and "table=<%s>" % self.table.description or "")] +
             ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg])
 
     def _get_parent(self):
@@ -504,12 +505,13 @@ class Column(SchemaItem, expression._ColumnClause):
             raise exceptions.ArgumentError("this Column already has a table!")
         if not self._is_oid:
             self._pre_existing_column = table._columns.get(self.key)
-            table._columns.add(self)
+
+            table._columns.replace(self)
         else:
             self._pre_existing_column = None
             
         if self.primary_key:
-            table.primary_key.add(self)
+            table.primary_key.replace(self)
         elif self.key in table.primary_key:
             raise exceptions.ArgumentError("Trying to redefine primary-key column '%s' as a non-primary-key column on table '%s'" % (self.key, table.fullname))
             # if we think this should not raise an error, we'd instead do this:
@@ -899,19 +901,20 @@ class PrimaryKeyConstraint(Constraint):
         self.table = table
         table.primary_key = self
         for c in self.__colnames:
-            self.append_column(table.c[c])
-
+            self.add(table.c[c])
+    
     def add(self, col):
-        self.append_column(col)
+        self.columns.add(col)
+        col.primary_key=True
+    append_column = add
+    
+    def replace(self, col):
+        self.columns.replace(col)
 
     def remove(self, col):
         col.primary_key=False
         del self.columns[col.key]
 
-    def append_column(self, col):
-        self.columns.add(col)
-        col.primary_key=True
-
     def copy(self):
         return PrimaryKeyConstraint(name=self.name, *[c.key for c in self])
 
index 7caee331440bd8fd9fff7bd68345a756e3d34f09..dabc10decbe6d29491ad5dc424527b76e08eaff0 100644 (file)
@@ -27,6 +27,7 @@ to stay the same in future releases.
 
 import re
 import datetime
+import warnings
 from sqlalchemy import util, exceptions
 from sqlalchemy.sql import operators, visitors
 from sqlalchemy import types as sqltypes
@@ -1464,6 +1465,27 @@ class ColumnCollection(util.OrderedProperties):
     def __str__(self):
         return repr([str(c) for c in self])
 
+    def replace(self, column):
+        """add the given column to this collection, removing unaliased versions of this column
+           as well as existing columns with the same key.
+        
+            e.g.::
+            
+                t = Table('sometable', Column('col1', Integer))
+                t.replace_unalised(Column('col1', Integer, key='columnone'))
+                
+            will remove the original 'col1' from the collection, and add 
+            the new column under the name 'columnname'.
+            
+           Used by schema.Column to override columns during table reflection.
+        """
+        
+        if column.name in self and column.key != column.name:
+            other = self[column.name]
+            if other.name == other.key:
+                del self[other.name]
+        util.OrderedProperties.__setitem__(self, column.key, column)
+        
     def add(self, column):
         """Add a column to this collection.
 
@@ -1471,14 +1493,18 @@ class ColumnCollection(util.OrderedProperties):
         for this dictionary.
         """
 
-        # Allow an aliased column to replace an unaliased column of the
-        # same name.
-        if column.name in self:
-            other = self[column.name]
-            if other.name == other.key:
-                del self[other.name]
         self[column.key] = column
-
+    
+    def __setitem__(self, key, value):
+        if key in self:
+            # this warning is primarily to catch select() statements which have conflicting
+            # column names in their exported columns collection
+            existing = self[key]
+            if not existing.shares_lineage(value):
+                table = getattr(existing, 'table', None) and existing.table.description
+                warnings.warn(RuntimeWarning("Column %r on table %r being replaced by another column with the same key.  Consider use_labels for select() statements."  % (key, table)))
+        util.OrderedProperties.__setitem__(self, key, value)
+        
     def remove(self, column):
         del self[column.key]
 
index 160a315524bcc13be55304cb362b518f01b3bd10..65a6ad8fa016ceb59d789810115dc81df960d934 100644 (file)
@@ -1232,8 +1232,7 @@ class RequirementsTest(AssertMixin):
         t5 = Table('ht5', metadata,
                    Column('ht1_id', Integer, ForeignKey('ht1.id'),
                           primary_key=True),
-                   Column('ht1_id', Integer, ForeignKey('ht1.id'),
-                          primary_key=True))
+                    )
         t6 = Table('ht6', metadata,
                    Column('ht1a_id', Integer, ForeignKey('ht1.id'),
                           primary_key=True),
index 09c6c2144c449c1065adb073488b368ec1b7e0cf..5f85151f0aab95ccd793720110087c4ced640fc7 100644 (file)
@@ -486,7 +486,29 @@ class ParentTest(QueryTest):
 
 
 class JoinTest(QueryTest):
-
+    
+    def test_getjoinable_tables(self):
+        sess = create_session()
+        
+        sel1 = select([users]).alias()
+        sel2 = select([users], from_obj=users.join(addresses)).alias()
+        
+        j1 = sel1.join(users, sel1.c.id==users.c.id)
+        j2 = j1.join(addresses)
+        
+        for from_obj, assert_cond in (
+            (users, [users]),
+            (users.join(addresses), [users, addresses]),
+            (sel1, [sel1]),
+            (sel2, [sel2]),
+            (sel1.join(users, sel1.c.id==users.c.id), [sel1, users]),
+            (sel2.join(users, sel2.c.id==users.c.id), [sel2, users]),
+            (j2, [j1, j2, sel1, users, addresses])
+            
+        ):
+            ret = set(sess.query(User).select_from(from_obj)._get_joinable_tables())
+            self.assertEquals(ret, set(assert_cond).union([from_obj]), [x.description for x in ret])
+        
     def test_overlapping_paths(self):
         for aliased in (True,False):
             # load a user who has an order that contains item id 3 and address id 1 (order 3, owned by jack)
@@ -995,6 +1017,18 @@ class SelectFromTest(QueryTest):
                     Order(description=u'order 5',items=[Item(description=u'item 5',keywords=[])])])
                 ])
         self.assert_sql_count(testbase.db, go, 1)
+
+        sess.clear()
+        sel2 = orders.select(orders.c.id.in_([1,2,3]))
+        self.assertEquals(sess.query(Order).select_from(sel2).join(['items', 'keywords']).filter(Keyword.name == 'red').all(), [
+            Order(description=u'order 1',id=1), 
+            Order(description=u'order 2',id=2), 
+        ])
+        self.assertEquals(sess.query(Order).select_from(sel2).join(['items', 'keywords'], aliased=True).filter(Keyword.name == 'red').all(), [
+            Order(description=u'order 1',id=1), 
+            Order(description=u'order 2',id=2), 
+        ])
+        
         
     def test_replace_with_eager(self):
         mapper(User, users, properties = {
@@ -1026,7 +1060,6 @@ class SelectFromTest(QueryTest):
             self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel)[1], User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)]))
         self.assert_sql_count(testbase.db, go, 1)
     
-        
 class CustomJoinTest(QueryTest):
     keep_mappers = False
 
index 49a61bf2b94d626e56112e3adb9fc7d99663c8ad..4796288dfa5e26185853b2911da1ce434d6637bb 100755 (executable)
@@ -47,7 +47,10 @@ class SelectableTest(AssertMixin):
         j2 = jjj.alias('foo')
         assert j2.corresponding_column(table.c.col1) is j2.c.table1_col1
         
-
+    def testselectontable(self):
+        sel = select([table, table2], use_labels=True)
+        assert sel.corresponding_column(table.c.col1) is sel.c.table1_col1
+        
     def testjoinagainstjoin(self):
         j  = outerjoin(table, table2, table.c.col1==table2.c.col2)
         jj = select([ table.c.col1.label('bar_col1')],from_obj=[j]).alias('foo')