]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- some expression fixup:
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 19 Jan 2008 18:36:52 +0000 (18:36 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 19 Jan 2008 18:36:52 +0000 (18:36 +0000)
- the '.c.' attribute on a selectable now gets an
entry for every column expression in its columns
clause; previously, "unnamed" columns like functions
and CASE statements weren't getting put there.  Now
they will, using their full string representation
if no 'name' is available.
- The anonymous 'label' generated for otherwise
unlabeled functions and expressions now propagates
outwards at compile time for expressions like
select([select([func.foo()])])
- a CompositeSelect, i.e. any union(), union_all(),
intersect(), etc. now asserts that each selectable
contains the same number of columns.  This conforms
to the corresponding SQL requirement.
- building on the above ideas, CompositeSelects
now build up their ".c." collection based on
the names present in the first selectable only;
corresponding_column() now works fully for all
embedded selectables.

CHANGES
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/sql/select.py
test/sql/selectable.py

diff --git a/CHANGES b/CHANGES
index 246853b3d97d9ca955380caa96cfae3eb7b28440..7087bcc84a300529e55dfabf69eab903be195c84 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -8,7 +8,28 @@ CHANGES
     - added "ilike()" operator to column operations. 
       compiles to ILIKE on postgres, lower(x) LIKE lower(y)
       on all others [ticket:727]
-      
+    
+    - some expression fixup:
+        - the '.c.' attribute on a selectable now gets an
+        entry for every column expression in its columns
+        clause; previously, "unnamed" columns like functions
+        and CASE statements weren't getting put there.  Now
+        they will, using their full string representation
+        if no 'name' is available.  
+        - The anonymous 'label' generated for otherwise
+        unlabeled functions and expressions now propagates 
+        outwards at compile time for expressions like 
+        select([select([func.foo()])])
+        - a CompositeSelect, i.e. any union(), union_all(),
+        intersect(), etc. now asserts that each selectable
+        contains the same number of columns.  This conforms
+        to the corresponding SQL requirement.
+        - building on the above ideas, CompositeSelects
+        now build up their ".c." collection based on
+        the names present in the first selectable only;
+        corresponding_column() now works fully for all 
+        embedded selectables.
+        
 - orm
     - proper error message is raised when trying to access
       expired instance attributes with no session present
index 8f2e3372a3a257c91ffeed0049412102e6f1db78..666a38d397c8ae69122e32fcf57d798460d5cd25 100644 (file)
@@ -464,7 +464,7 @@ class DefaultCompiler(engine.Compiled):
             not isinstance(column.table, sql.Select):
             return column.label(column.name)
         elif not isinstance(column, (sql._UnaryExpression, sql._TextClause)) and (not hasattr(column, 'name') or isinstance(column, sql._Function)):
-            return column.label(None)
+            return column.anon_label
         else:
             return column
 
@@ -728,7 +728,7 @@ class DefaultCompiler(engine.Compiled):
         return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
     
     def __str__(self):
-        return self.string
+        return self.string or ''
 
 class DDLBase(engine.SchemaIterator):
     def find_alterables(self, tables):
index 3ebc4960fac9349c47ced5ac109d8baa9330babe..c603418028f443d46302b07d5a78002efe9eeea0 100644 (file)
@@ -1402,17 +1402,37 @@ class ColumnElement(ClauseElement, _CompareMixin):
         ``ColumnElement`` as it appears in the select list of a
         descending selectable.
 
-        The default implementation returns a ``_ColumnClause`` if a
-        name is given, else just returns self.
         """
 
         if name is not None:
-            co = _ColumnClause(name, selectable)
+            co = _ColumnClause(name, selectable, type_=getattr(self, 'type', None))
             co.proxies = [self]
             selectable.columns[name]= co
             return co
         else:
-            return self
+            name = str(self)
+            co = _ColumnClause(self.anon_label.name, selectable, type_=getattr(self, 'type', None))
+            co.proxies = [self]
+            selectable.columns[name] = co
+            return co
+    
+    def anon_label(self):
+        """provides a constant 'anonymous label' for this ColumnElement.
+        
+        This is a label() expression which will be named at compile time.
+        The same label() is returned each time anon_label is called so 
+        that expressions can reference anon_label multiple times, producing
+        the same label name at compile time.
+        
+        the compiler uses this function automatically at compile time
+        for expressions that are known to be 'unnamed' like binary
+        expressions and function calls.
+        """
+
+        if not hasattr(self, '_ColumnElement__anon_label'):
+            self.__anon_label = self.label(None)
+        return self.__anon_label
+    anon_label = property(anon_label)
 
 class ColumnCollection(util.OrderedProperties):
     """An ordered dictionary that stores a list of ColumnElement
@@ -2026,15 +2046,6 @@ class _Cast(ColumnElement):
     def _get_from_objects(self, **modifiers):
         return self.clause._get_from_objects(**modifiers)
 
-    def _make_proxy(self, selectable, name=None):
-        if name is not None:
-            co = _ColumnClause(name, selectable, type_=self.type)
-            co.proxies = [self]
-            selectable.columns[name]= co
-            return co
-        else:
-            return self
-
 
 class _UnaryExpression(ColumnElement):
     def __init__(self, element, operator=None, modifier=None, type_=None, negate=None):
@@ -2864,8 +2875,16 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
         self.keyword = keyword
         self.selects = []
 
+        numcols = None
+        
         # some DBs do not like ORDER BY in the inner queries of a UNION, etc.
         for n, s in enumerate(selects):
+            if not numcols:
+                numcols = len(s.c)
+            elif len(s.c) != numcols:
+                raise exceptions.ArgumentError("All selectables passed to CompoundSelect must have identical numbers of columns; select #%d has %d columns, select #%d has %d" % 
+                    (1, len(self.selects[0].c), n+1, len(s.c))
+                )
             if s._order_by_clause:
                 s = s.order_by(None)
             # unions group from left to right, so don't group first select
@@ -2892,17 +2911,22 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
                 yield c
 
     def _proxy_column(self, column):
-        existing = self._col_map.get(column.name, None)
-        if existing is not None:
-            existing.proxies.append(column)
-            return existing
-        else:
+        selectable = column.table
+        col_ordering = self._col_map.get(selectable, None)
+        if col_ordering is None:
+            self._col_map[selectable] = col_ordering = []
+        
+        if selectable is self.selects[0]:
             if self.use_labels:
                 col = column._make_proxy(self, name=column._label)
             else:
                 col = column._make_proxy(self)
-            self._col_map[col.name] = col
-            return col
+            col_ordering.append(col)
+        else:
+            col_ordering.append(column)
+            existing = self._col_map[self.selects[0]][len(col_ordering) - 1]
+            existing.proxies.append(column)
+            return existing
 
     def _copy_internals(self, clone=_clone):
         self._clone_from_clause()
index 07c3ce69e6ef9c1550dbbc1f50acb0aa82a9af41..c34cec7c516c18ebed4062ce6710edf8d3a326f5 100644 (file)
@@ -845,95 +845,94 @@ mytable.description FROM myothertable JOIN mytable ON mytable.myid = myothertabl
             )
 
     def testunion(self):
-            x = union(
-                  select([table1], table1.c.myid == 5),
-                  select([table1], table1.c.myid == 12),
-                  order_by = [table1.c.myid],
-            )
+        try:
+            union(table3.select(), table1.select())
+        except exceptions.ArgumentError, err:
+            assert str(err) == "All selectables passed to CompoundSelect must have identical numbers of columns; select #1 has 2 columns, select #2 has 3"
+    
+        x = union(
+              select([table1], table1.c.myid == 5),
+              select([table1], table1.c.myid == 12),
+              order_by = [table1.c.myid],
+        )
 
-            self.assert_compile(x, "SELECT mytable.myid, mytable.name, mytable.description \
+        self.assert_compile(x, "SELECT mytable.myid, mytable.name, mytable.description \
 FROM mytable WHERE mytable.myid = :mytable_myid_1 UNION \
 SELECT mytable.myid, mytable.name, mytable.description \
 FROM mytable WHERE mytable.myid = :mytable_myid_2 ORDER BY mytable.myid")
 
-            self.assert_compile(
-                    union(
-                        select([table1]),
-                        select([table2]),
-                        select([table3])
-                    )
-            ,
-            "SELECT mytable.myid, mytable.name, mytable.description \
+        u1 = union(
+            select([table1.c.myid, table1.c.name]),
+            select([table2]),
+            select([table3])
+        )
+        self.assert_compile(u1,
+        "SELECT mytable.myid, mytable.name \
 FROM mytable UNION SELECT myothertable.otherid, myothertable.othername \
 FROM myothertable UNION SELECT thirdtable.userid, thirdtable.otherstuff FROM thirdtable")
 
-            u = union(
-                select([table1]),
+        assert u1.corresponding_column(table2.c.otherid) is u1.c.myid
+
+        # TODO - why is there an extra space before the LIMIT ?
+        self.assert_compile(
+            union(
+                select([table1.c.myid, table1.c.name]),
                 select([table2]),
-                select([table3])
+                order_by=['myid'],
+                offset=10,
+                limit=5
             )
-            assert u.corresponding_column(table2.c.otherid) is u.c.otherid
-
-            self.assert_compile(
-                union(
-                    select([table1]),
-                    select([table2]),
-                    order_by=['myid'],
-                    offset=10,
-                    limit=5
-                )
-            ,    "SELECT mytable.myid, mytable.name, mytable.description \
+        ,    "SELECT mytable.myid, mytable.name \
 FROM mytable UNION SELECT myothertable.otherid, myothertable.othername \
-FROM myothertable ORDER BY myid \
- LIMIT 5 OFFSET 10"
-            )
+FROM myothertable ORDER BY myid  LIMIT 5 OFFSET 10"
+        )
 
-            self.assert_compile(
-                union(
-                    select([table1.c.myid, table1.c.name, func.max(table1.c.description)], table1.c.name=='name2', group_by=[table1.c.myid, table1.c.name]),
-                    table1.select(table1.c.name=='name1')
-                )
-                ,
-                "SELECT mytable.myid, mytable.name, max(mytable.description) AS max_1 FROM mytable \
+        self.assert_compile(
+            union(
+                select([table1.c.myid, table1.c.name, func.max(table1.c.description)], table1.c.name=='name2', group_by=[table1.c.myid, table1.c.name]),
+                table1.select(table1.c.name=='name1')
+            )
+            ,
+            "SELECT mytable.myid, mytable.name, max(mytable.description) AS max_1 FROM mytable \
 WHERE mytable.name = :mytable_name_1 GROUP BY mytable.myid, mytable.name UNION SELECT mytable.myid, mytable.name, mytable.description \
 FROM mytable WHERE mytable.name = :mytable_name_2"
-            )
+        )
 
-            self.assert_compile(
-                union(
-                    select([literal(100).label('value')]),
-                    select([literal(200).label('value')])
-                    ),
-                    "SELECT :param_1 AS value UNION SELECT :param_2 AS value"
-            )
+        self.assert_compile(
+            union(
+                select([literal(100).label('value')]),
+                select([literal(200).label('value')])
+                ),
+                "SELECT :param_1 AS value UNION SELECT :param_2 AS value"
+        )
 
 
     def test_compound_select_grouping(self):
-            self.assert_compile(
-                union_all(
-                    select([table1.c.myid]),
-                    union(
-                        select([table2.c.otherid]),
-                        select([table3.c.userid]),
-                    )
+        self.assert_compile(
+            union_all(
+                select([table1.c.myid]),
+                union(
+                    select([table2.c.otherid]),
+                    select([table3.c.userid]),
                 )
-                ,
-                "SELECT mytable.myid FROM mytable UNION ALL (SELECT myothertable.otherid FROM myothertable UNION \
+            )
+            ,
+            "SELECT mytable.myid FROM mytable UNION ALL (SELECT myothertable.otherid FROM myothertable UNION \
 SELECT thirdtable.userid FROM thirdtable)"
+        )
+        # This doesn't need grouping, so don't group to not give sqlite unnecessarily hard time
+        self.assert_compile(
+            union(
+                except_(
+                    select([table2.c.otherid]),
+                    select([table3.c.userid]),
+                ),
+                select([table1.c.myid])
             )
-            # This doesn't need grouping, so don't group to not give sqlite unnecessarily hard time
-            self.assert_compile(
-                union(
-                    except_(
-                        select([table2.c.otherid]),
-                        select([table3.c.userid]),
-                    ),
-                    select([table1.c.myid])
-                )
-                ,
-                "SELECT myothertable.otherid FROM myothertable EXCEPT SELECT thirdtable.userid FROM thirdtable \
+            ,
+            "SELECT myothertable.otherid FROM myothertable EXCEPT SELECT thirdtable.userid FROM thirdtable \
 UNION SELECT mytable.myid FROM mytable"
-            )
+        )
 
     def testouterjoin(self):
         query = select(
@@ -1253,7 +1252,40 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE
             "SELECT op.field FROM op WHERE (op.field = op.field) BETWEEN :param_1 AND :param_2")
         self.assert_compile(table.select(between((table.c.field == table.c.field), False, True)),
             "SELECT op.field FROM op WHERE (op.field = op.field) BETWEEN :param_1 AND :param_2")
-
+    
+    def test_naming(self):
+        s1 = select([table1.c.myid, table1.c.myid.label('foobar'), func.hoho(table1.c.name), func.lala(table1.c.name).label('gg')])
+        assert s1.c.keys() == ['myid', 'foobar', 'hoho(mytable.name)', 'gg']
+
+        from sqlalchemy.databases.sqlite import SLNumeric
+        meta = MetaData()
+        t1 = Table('mytable', meta, Column('col1', Integer))
+        
+        for col, key, expr, label in (
+            (table1.c.name, 'name', 'mytable.name', None),
+            (table1.c.myid==12, 'mytable.myid = :mytable_myid_1', 'mytable.myid = :mytable_myid_1', 'anon_1'),
+            (func.hoho(table1.c.myid), 'hoho(mytable.myid)', 'hoho(mytable.myid)', 'hoho_1'),
+            (cast(table1.c.name, SLNumeric), 'CAST(mytable.name AS NUMERIC(10, 2))', 'CAST(mytable.name AS NUMERIC(10, 2))', 'anon_1'),
+            (t1.c.col1, 'col1', 'mytable.col1', None),
+            (column('some wacky thing'), 'some wacky thing', '"some wacky thing"', '')
+        ):
+            s1 = select([col], from_obj=getattr(col, 'table', None) or table1)
+            assert s1.c.keys() == [key], s1.c.keys()
+        
+            if label:
+                self.assert_compile(s1, "SELECT %s AS %s FROM mytable" % (expr, label))
+            else:
+                self.assert_compile(s1, "SELECT %s FROM mytable" % (expr,))
+            
+            s1 = select([s1])
+            if label:
+                self.assert_compile(s1, "SELECT %s FROM (SELECT %s AS %s FROM mytable)" % (label, expr, label))
+            elif col.table is not None:
+                # sqlite rule labels subquery columns
+                self.assert_compile(s1, "SELECT %s FROM (SELECT %s AS %s FROM mytable)" % (key,expr, key))
+            else:
+                self.assert_compile(s1, "SELECT %s FROM (SELECT %s FROM mytable)" % (expr,expr))
+                
 class CRUDTest(SQLCompileTest):
     def testinsert(self):
         # generic insert, will create bind params for all columns
index 45bd7d823a2bdfdab408d7e258753d4da3b0e00e..8a25db184260d7f375f8dde0bb4bf68f6b3d84c8 100755 (executable)
@@ -74,8 +74,6 @@ class SelectableTest(AssertMixin):
         j = join(a, table2)
 
         criterion = a.c.col1 == table2.c.col2
-        print
-        print str(j)
         self.assert_(criterion.compare(j.onclause))
 
     def testunion(self):
@@ -213,7 +211,7 @@ class SelectableTest(AssertMixin):
         assert u.corresponding_column(table2.oid_column) is u.oid_column
         assert u.corresponding_column(s.oid_column) is u.oid_column
         assert u.corresponding_column(s2.oid_column) is u.oid_column
-
+    
 class PrimaryKeyTest(AssertMixin):
     def test_join_pk_collapse_implicit(self):
         """test that redundant columns in a join get 'collapsed' into a minimal primary key,