]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- decoupled all ColumnElements from also being Selectables. this means
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 9 Aug 2007 21:50:23 +0000 (21:50 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 9 Aug 2007 21:50:23 +0000 (21:50 +0000)
that anything which is a column expression does not have a "c" or a
"columns" attribute.  Also works for select().as_scalar(); _ScalarSelect
is a columnelement, so you can't say select().as_scalar().c.foo, which is
a pretty confusing mistake to make.  in the case of _ScalarSelect made
an explicit raise if you try to access 'c'.

doc/build/testdocs.py
lib/sqlalchemy/sql.py
test/sql/select.py

index 15986c5123e4f0fc7b0bc455b672f6821fdf0ce1..0f5f693300f3854b261b1cf8c862e63cecab31f7 100644 (file)
@@ -63,7 +63,6 @@ def replace_file(s, newfile):
     return s\r
 \r
 for filename in ('ormtutorial', 'sqlexpression'):\r
-#for filename in ('sqlexpression',):\r
        filename = 'content/%s.txt' % filename\r
        s = open(filename).read()\r
        #s = replace_file(s, ':memory:')\r
index bc255c5af5fa9a903632d3858f8c0c4b9f26aabe..cd9e4b4aa4519ef92732ac1bfb5c7d6d9254aa1f 100644 (file)
@@ -1549,8 +1549,7 @@ class Selectable(ClauseElement):
     def select(self, whereclauses = None, **params):
         return select([self], whereclauses, **params)
 
-
-class ColumnElement(Selectable, _CompareMixin):
+class ColumnElement(ClauseElement, _CompareMixin):
     """Represent an element that is useable within the 
     "column clause" portion of a ``SELECT`` statement. 
     
@@ -1582,11 +1581,6 @@ class ColumnElement(Selectable, _CompareMixin):
         which each represent a foreign key placed on this column's ultimate
         ancestor.
         """)
-    columns = property(lambda self:[self],
-                       doc=\
-        """Columns accessor which returns ``self``, to provide compatibility 
-        with ``Selectable`` objects.
-        """)
 
     def _one_fkey(self):
         if self._foreign_keys:
@@ -1909,13 +1903,13 @@ class FromClause(Selectable):
         """return the list of ColumnElements represented within this FromClause's _exportable_columns"""
         export = self._exportable_columns()
         for column in export:
-            # TODO: is this conditional needed ?
             if isinstance(column, Selectable):
-                s = column
+                for co in column.columns:
+                    yield co
+            elif isinstance(column, ColumnElement):
+                yield column
             else:
                 continue
-            for co in s.columns:
-                yield co
         
     def _exportable_columns(self):
         return []
@@ -2219,7 +2213,8 @@ class _Function(_CalculatedClause, FromClause):
             self.append(c)
 
     key = property(lambda self:self.name)
-
+    columns = property(lambda self:[self])
+    
     def _copy_internals(self):
         _CalculatedClause._copy_internals(self)
         self._clone_from_clause()
@@ -2353,18 +2348,21 @@ class _Exists(_UnaryExpression):
     
     def __init__(self, *args, **kwargs):
         kwargs['correlate'] = True
-        s = select(*args, **kwargs).self_group()
+        s = select(*args, **kwargs).as_scalar().self_group()
         _UnaryExpression.__init__(self, s, operator=Operators.exists)
 
+    def select(self, whereclauses = None, **params):
+        return select([self], whereclauses, **params)
+
     def correlate(self, fromclause):
-      e = self._clone()
-      e.element = self.element.correlate(fromclause).self_group()
-      return e
+        e = self._clone()
+        e.element = self.element.correlate(fromclause).self_group()
+        return e
     
     def where(self, clause):
-      e = self._clone()
-      e.element = self.element.where(clause).self_group()
-      return e
+        e = self._clone()
+        e.element = self.element.where(clause).self_group()
+        return e
       
     def _hide_froms(self, **modifiers):
         return self._get_from_objects(**modifiers)
@@ -2427,7 +2425,7 @@ class Join(FromClause):
     primary_key = property(lambda s:s.__primary_key)
 
     def self_group(self, against=None):
-        return _Grouping(self)
+        return _FromGrouping(self)
         
     def _locate_oid_column(self):
         return self.left.oid_column
@@ -2639,7 +2637,6 @@ class _ColumnElementAdapter(ColumnElement):
         
     key = property(lambda s: s.elem.key)
     _label = property(lambda s: s.elem._label)
-    columns = c = property(lambda s:s.elem.columns)
 
     def _copy_internals(self):
         self.elem = self.elem._clone()
@@ -2657,8 +2654,33 @@ class _ColumnElementAdapter(ColumnElement):
         return getattr(self.elem, attr)
 
 class _Grouping(_ColumnElementAdapter):
+    """represent a grouping within a column expression"""
     pass
 
+class _FromGrouping(FromClause):
+    """represent a grouping of a FROM clause"""
+    __visit_name__ = 'grouping'
+
+    def __init__(self, elem):
+        self.elem = elem
+
+    columns = c = property(lambda s:s.elem.columns)
+
+    def get_children(self, **kwargs):
+        return self.elem,
+
+    def _hide_froms(self, **modifiers):
+        return self.elem._hide_froms(**modifiers)
+
+    def _copy_internals(self):
+        self.elem = self.elem._clone()
+
+    def _get_from_objects(self, **modifiers):
+        return self.elem._get_from_objects(**modifiers)
+
+    def __getattr__(self, attr):
+        return getattr(self.elem, attr)
+    
 class _Label(ColumnElement):
     """represent a label, as typically applied to any column-level element
     using the ``AS`` sql keyword.
@@ -2698,7 +2720,7 @@ class _Label(ColumnElement):
         return self.obj._hide_froms(**modifiers)
         
     def _make_proxy(self, selectable, name = None):
-        if isinstance(self.obj, Selectable):
+        if isinstance(self.obj, (Selectable, ColumnElement)):
             return self.obj._make_proxy(selectable, name=self.name)
         else:
             return column(self.name)._make_proxy(selectable=selectable)
@@ -2979,7 +3001,10 @@ class _ScalarSelect(_Grouping):
         super(_ScalarSelect, self).__init__(elem)
         self.type = list(elem.inner_columns)[0].type
 
-    columns = property(lambda self:[self])
+    def _no_cols(self):
+        raise exceptions.InvalidRequestError("Scalar Select expression has no columns; use this object directly within a column-level expression.")
+    c = property(_no_cols)
+    columns = c
     
     def self_group(self, **kwargs):
         return self
@@ -3013,7 +3038,7 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
     name = property(lambda s:s.keyword + " statement")
 
     def self_group(self, against=None):
-        return _Grouping(self)
+        return _FromGrouping(self)
 
     def _locate_oid_column(self):
         return self.selects[0].oid_column
@@ -3143,6 +3168,11 @@ class Select(_SelectBaseMixin, FromClause):
             return froms
     
     froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""")
+
+    name = property(lambda self:"Select statement")
+
+    def expression_element(self):
+        return self.as_scalar()
     
     def locate_all_froms(self):
         froms = util.Set()
@@ -3346,7 +3376,7 @@ class Select(_SelectBaseMixin, FromClause):
         self._froms.add(fromclause)
 
     def _exportable_columns(self):
-        return [c for c in self._raw_columns if isinstance(c, Selectable)]
+        return [c for c in self._raw_columns if isinstance(c, (Selectable, ColumnElement))]
         
     def _proxy_column(self, column):
         if self.use_labels:
@@ -3357,7 +3387,7 @@ class Select(_SelectBaseMixin, FromClause):
     def self_group(self, against=None):
         if isinstance(against, CompoundSelect):
             return self
-        return _Grouping(self)
+        return _FromGrouping(self)
 
     def _locate_oid_column(self):
         for f in self.locate_all_froms():
@@ -3405,14 +3435,13 @@ class Select(_SelectBaseMixin, FromClause):
                 return e
         # look through the columns (largely synomous with looking
         # through the FROMs except in the case of _CalculatedClause/_Function)
-        for cc in self._exportable_columns():
-            for c in cc.columns:
-                if getattr(c, 'table', None) is self:
-                    continue
-                e = c.bind
-                if e is not None:
-                    self._bind = e
-                    return e
+        for c in self._exportable_columns():
+            if getattr(c, 'table', None) is self:
+                continue
+            e = c.bind
+            if e is not None:
+                self._bind = e
+                return e
         return None
 
 class _UpdateBase(ClauseElement):
index 0edaab071c3e2c74aaa3e8ddc144ab070f15b496..32a889b489f4a90f86877f5c7afcc93fc18b8d6d 100644 (file)
@@ -64,6 +64,20 @@ class SQLTest(PersistTest):
                 self.assert_(c.get_params().get_original_dict() == checkparams, "params dont match" + repr(c.get_params()))
             
 class SelectTest(SQLTest):
+    
+    def test_attribute_sanity(self):
+        assert hasattr(table1, 'c')
+        assert hasattr(table1.select(), 'c')
+        assert not hasattr(table1.c.myid.self_group(), 'columns')
+        assert hasattr(table1.select().self_group(), 'columns')
+        assert not hasattr(table1.select().as_scalar().self_group(), 'columns')
+        assert not hasattr(table1.c.myid, 'columns')
+        assert not hasattr(table1.c.myid, 'c')
+        assert not hasattr(table1.select().c.myid, 'c')
+        assert not hasattr(table1.select().c.myid, 'columns')
+        assert not hasattr(table1.alias().c.myid, 'columns')
+        assert not hasattr(table1.alias().c.myid, 'c')
+        
     def testtableselect(self):
         self.runtest(table1.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable")
 
@@ -144,7 +158,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
         self.runtest(select([table1], from_obj=[table1, table1.select()]), """SELECT mytable.myid, mytable.name, mytable.description FROM mytable, (SELECT mytable.myid AS myid, mytable.name AS name, mytable.description AS description FROM mytable)""")
     
     def testexistsascolumnclause(self):
-        self.runtest(exists([table1.c.myid], table1.c.myid==5).select(), "SELECT EXISTS (SELECT mytable.myid AS myid FROM mytable WHERE mytable.myid = :mytable_myid)", params={'mytable_myid':5})
+        self.runtest(exists([table1.c.myid], table1.c.myid==5).select(), "SELECT EXISTS (SELECT mytable.myid FROM mytable WHERE mytable.myid = :mytable_myid)", params={'mytable_myid':5})
 
         self.runtest(select([table1, exists([1], from_obj=[table2])]), "SELECT mytable.myid, mytable.name, mytable.description, EXISTS (SELECT 1 FROM myothertable) FROM mytable", params={})
 
@@ -177,10 +191,10 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
             select([users, s.c.street], from_obj=[s]),
             """SELECT users.user_id, users.user_name, users.password, s.street FROM users, (SELECT addresses.street AS street FROM addresses WHERE addresses.user_id = users.user_id) AS s""")
 
-        # TODO: this tests that you dont get a "SELECT column" without a FROM but its not working yet.
-        #self.runtest(
-        #    table1.select(table1.c.myid == select([table1.c.myid], table1.c.name=='jack')), ""
-        #)
+        self.runtest(
+            table1.select(table1.c.myid == select([table1.c.myid], table1.c.name=='jack')), 
+            "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = (SELECT mytable.myid FROM mytable WHERE mytable.name = :mytable_name)"
+        )
         
         self.runtest(
             table1.select(table1.c.myid == select([table2.c.otherid], table1.c.name == table2.c.othername)),
@@ -224,7 +238,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
         )
         
         
-    def testcolumnsubquery(self):
+    def test_scalar_select(self):
         s = select([table1.c.myid], scalar=True, correlate=False)
         self.runtest(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) FROM mytable")
 
@@ -244,7 +258,18 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
 
         self.runtest(select([select([table1.c.name]).label('foo')]), "SELECT (SELECT mytable.name FROM mytable) AS foo")
 
+        # scalar selects should not have any attributes on their 'c' or 'columns' attribute
+        s = select([table1.c.myid]).as_scalar()
+        try:
+            s.c.foo
+        except exceptions.InvalidRequestError, err:
+            assert str(err) == 'Scalar Select expression has no columns; use this object directly within a column-level expression.'
 
+        try:
+            s.columns.foo
+        except exceptions.InvalidRequestError, err:
+            assert str(err) == 'Scalar Select expression has no columns; use this object directly within a column-level expression.'
+        
         zips = table('zips',
             column('zipcode'),
             column('latitude'),