]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- cleanup on some instance vars in Select (is_scalar, is_subquery, _froms is __froms...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 25 Nov 2006 21:32:26 +0000 (21:32 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 25 Nov 2006 21:32:26 +0000 (21:32 +0000)
- cleaned up __repr__ on Column, AbstractTypeEngine
- added standalone intersect(_all), except(_all) functions, unit tests illustrating nesting patterns [ticket:247]

CHANGES
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/ext/sqlsoup.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/types.py
test/ext/activemapper.py
test/sql/query.py

diff --git a/CHANGES b/CHANGES
index 98404cdd8d4709c65c656f787f8f7c9a7ea426f6..b7504b9670fad58bb5321adda477bd8c2df86653 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -1,3 +1,7 @@
+0.3.2
+- added keywords for EXCEPT, INTERSECT, EXCEPT ALL, INTERSECT ALL
+[ticket:247]
+
 0.3.1
 - Engine/Pool:
   - some new Pool utility classes, updated docs
index a0ce64905de81fceb536f17d195db9433705cba0..2e0fe6e34739a336a131e281ef049e571c9bfc2b 100644 (file)
@@ -302,7 +302,7 @@ class ANSICompiler(sql.Compiled):
         self.select_stack.append(select)
         for c in select._raw_columns:
             # TODO: make this polymorphic?
-            if isinstance(c, sql.Select) and c._scalar:
+            if isinstance(c, sql.Select) and c.is_scalar:
                 c.accept_visitor(self)
                 inner_columns[self.get_str(c)] = c
                 continue
@@ -319,7 +319,7 @@ class ANSICompiler(sql.Compiled):
                     inner_columns[co._label] = l
                 # TODO: figure this out, a ColumnClause with a select as a parent
                 # is different from any other kind of parent
-                elif select.issubquery and isinstance(co, sql._ColumnClause) and co.table is not None and not isinstance(co.table, sql.Select):
+                elif select.is_subquery and isinstance(co, sql._ColumnClause) and co.table is not None and not isinstance(co.table, sql.Select):
                     # SQLite doesnt like selecting from a subquery where the column
                     # names look like table.colname, so add a label synonomous with
                     # the column name
index d83ecfb590192ea0c89789fe2f227f91cc1fef8e..d3081bc237d3256b9c6e731fbcd6d15e5a77ba10 100644 (file)
@@ -324,9 +324,7 @@ def _selectable_name(selectable):
     if isinstance(selectable, sql.Alias):
         return _selectable_name(selectable.selectable)
     elif isinstance(selectable, sql.Select):
-        # sometimes a Select has itself in _froms
-        nonrecursive_froms = [s for s in selectable._froms if s is not selectable]
-        return ''.join([_selectable_name(s) for s in nonrecursive_froms])
+        return ''.join([_selectable_name(s) for s in selectable.froms])
     elif isinstance(selectable, schema.Table):
         return selectable.name.capitalize()
     else:
index fb6894f0bc06b98d975dc922e9e86cedd4c9538c..d9a7684e72fc84d6d68959e2cfb3adc858ddeb32 100644 (file)
@@ -241,7 +241,7 @@ class Table(SchemaItem, sql.TableClause):
         [repr(self.name)] + [repr(self.metadata)] +
         [repr(x) for x in self.columns] +
         ["%s=%s" % (k, repr(getattr(self, k))) for k in ['schema']]
-       , ',\n')
+       , ',')
     
     def __str__(self):
         return _get_table_key(self.name, self.schema)
@@ -401,10 +401,22 @@ class Column(SchemaItem, sql._ColumnClause):
         fk._set_parent(self)
             
     def __repr__(self):
-       return "Column(%s)" % string.join(
+        kwarg = []
+        if self.key != self.name:
+            kwarg.append('key')
+        if self._primary_key:
+            kwarg.append('primary_key')
+        if not self.nullable:
+            kwarg.append('nullable')
+        if self.onupdate:
+            kwarg.append('onupdate')
+        if self.default:
+            kwarg.append('default')
+        return "Column(%s)" % string.join(
         [repr(self.name)] + [repr(self.type)] +
         [repr(x) for x in self.foreign_keys if x is not None] +
-        ["%s=%s" % (k, repr(getattr(self, k))) for k in ['key', 'primary_key', 'nullable', 'default', 'onupdate']]
+        [repr(x) for x in self.constraints] +
+        ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg]
        , ',')
         
     def _get_parent(self):
index ce33810a523eb007d03607c27262e8e65710be10..b5faf37fef75cff6d7ac30f434ae3b0f7f48610a 100644 (file)
@@ -9,7 +9,8 @@ from sqlalchemy import util, exceptions
 from sqlalchemy import types as sqltypes
 import string, re, random, sets
 
-__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'case', 'cast', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists', 'extract','AbstractDialect', 'ClauseParameters', 'ClauseVisitor', 'Executor', 'Compiled', 'ClauseElement', 'ColumnElement', 'ColumnCollection', 'FromClause', 'TableClause', 'Select', 'Alias', 'CompoundSelect','Join', 'Selectable']
+
+__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'case', 'cast', 'union', 'union_all', 'except_', 'except_all', 'intersect', 'intersect_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists', 'extract','AbstractDialect', 'ClauseParameters', 'ClauseVisitor', 'Executor', 'Compiled', 'ClauseElement', 'ColumnElement', 'ColumnCollection', 'FromClause', 'TableClause', 'Select', 'Alias', 'CompoundSelect','Join', 'Selectable']
 
 def desc(column):
     """return a descending ORDER BY clause element, e.g.:
@@ -181,6 +182,18 @@ def union(*selects, **params):
 def union_all(*selects, **params):
     return _compound_select('UNION ALL', *selects, **params)
 
+def except_(*selects, **params):
+    return _compound_select('EXCEPT', *selects, **params)
+
+def except_all(*selects, **params):
+    return _compound_select('EXCEPT ALL', *selects, **params)
+
+def intersect(*selects, **params):
+    return _compound_select('INTERSECT', *selects, **params)
+
+def intersect_all(*selects, **params):
+    return _compound_select('INTERSECT ALL', *selects, **params)
+
 def alias(*args, **params):
     return Alias(*args, **params)
 
@@ -1357,7 +1370,7 @@ class _SelectBaseMixin(object):
     def select(self, whereclauses = None, **params):
         return select([self], whereclauses, **params)
     def _get_from_objects(self):
-        if self.is_where or self._scalar:
+        if self.is_where or self.is_scalar:
             return []
         else:
             return [self]
@@ -1366,19 +1379,27 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
     def __init__(self, keyword, *selects, **kwargs):
         _SelectBaseMixin.__init__(self)
         self.keyword = keyword
-        self.selects = selects
         self.use_labels = kwargs.pop('use_labels', False)
         self.parens = kwargs.pop('parens', False)
         self.correlate = kwargs.pop('correlate', False)
         self.for_update = kwargs.pop('for_update', False)
         self.nowait = kwargs.pop('nowait', False)
-        self.limit = kwargs.get('limit', None)
-        self.offset = kwargs.get('offset', None)
-        for s in self.selects:
+        self.limit = kwargs.pop('limit', None)
+        self.offset = kwargs.pop('offset', None)
+        self.is_compound = True
+        self.is_where = False
+        self.is_scalar = False
+
+        self.selects = selects
+
+        for s in selects:
             s.group_by(None)
             s.order_by(None)
-        self.group_by(*kwargs.get('group_by', [None]))
-        self.order_by(*kwargs.get('order_by', [None]))
+            
+        self.group_by(*kwargs.pop('group_by', [None]))
+        self.order_by(*kwargs.pop('order_by', [None]))
+        if len(kwargs):
+            raise TypeError("invalid keyword argument(s) for CompoundSelect: %s" % repr(kwargs.keys()))
         self._col_map = {}
 
     name = property(lambda s:s.keyword + " statement")
@@ -1420,9 +1441,9 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
 class Select(_SelectBaseMixin, FromClause):
     """represents a SELECT statement, with appendable clauses, as well as 
     the ability to execute itself and return a result set."""
-    def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, for_update=False, nowait=False, engine=None, limit=None, offset=None, scalar=False, correlate=True):
+    def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, for_update=False, engine=None, limit=None, offset=None, scalar=False, correlate=True):
         _SelectBaseMixin.__init__(self)
-        self._froms = util.OrderedDict()
+        self.__froms = util.OrderedDict()
         self.use_labels = use_labels
         self.whereclause = None
         self.having = None
@@ -1430,31 +1451,29 @@ class Select(_SelectBaseMixin, FromClause):
         self.limit = limit
         self.offset = offset
         self.for_update = for_update
-        self.nowait = nowait
+        self.is_compound = False
 
         # indicates that this select statement should not expand its columns
         # into the column clause of an enclosing select, and should instead
         # act like a single scalar column
-        self._scalar = scalar
+        self.is_scalar = scalar
 
         # indicates if this select statement, as a subquery, should correlate
         # its FROM clause to that of an enclosing select statement
         self.correlate = correlate
         
         # indicates if this select statement is a subquery inside another query
-        self.issubquery = False
+        self.is_subquery = False
         
         # indicates if this select statement is a subquery as a criterion
         # inside of a WHERE clause
         self.is_where = False
         
         self.distinct = distinct
-        self._text = None
         self._raw_columns = []
         self._correlated = None
-        self._correlator = Select._CorrelatedVisitor(self, False)
-        self._wherecorrelator = Select._CorrelatedVisitor(self, True)
-
+        self.__correlator = Select._CorrelatedVisitor(self, False)
+        self.__wherecorrelator = Select._CorrelatedVisitor(self, True)
 
         self.group_by(*(group_by or [None]))
         self.order_by(*(order_by or [None]))
@@ -1471,10 +1490,6 @@ class Select(_SelectBaseMixin, FromClause):
         for f in from_obj:
             self.append_from(f)
     
-    def _foo(self):
-        raise "this is a temporary assertion while we refactor SQL to not call 'name' on non-table Selectables"    
-    name = property(lambda s:s._foo()) #"SELECT statement")
-    
     class _CorrelatedVisitor(ClauseVisitor):
         """visits a clause, locates any Select clauses, and tells them that they should
         correlate their FROM list to that of their parent."""
@@ -1491,12 +1506,12 @@ class Select(_SelectBaseMixin, FromClause):
             if select is self.select:
                 return
             select.is_where = self.is_where
-            select.issubquery = True
+            select.is_subquery = True
             select.parens = True
             if not select.correlate:
                 return
             if getattr(select, '_correlated', None) is None:
-                select._correlated = self.select._froms
+                select._correlated = self.select._Select__froms
                 
     def append_column(self, column):
         if _is_literal(column):
@@ -1506,12 +1521,13 @@ class Select(_SelectBaseMixin, FromClause):
 
         # if the column is a Select statement itself, 
         # accept visitor
-        column.accept_visitor(self._correlator)
+        column.accept_visitor(self.__correlator)
         
         # visit the FROM objects of the column looking for more Selects
         for f in column._get_from_objects():
-            f.accept_visitor(self._correlator)
-        column._process_from_dict(self._froms, False)
+            f.accept_visitor(self.__correlator)
+        column._process_from_dict(self.__froms, False)
+
     def _exportable_columns(self):
         return self._raw_columns
     def _proxy_column(self, column):
@@ -1526,23 +1542,23 @@ class Select(_SelectBaseMixin, FromClause):
     def _append_condition(self, attribute, condition):
         if type(condition) == str:
             condition = _TextClause(condition)
-        condition.accept_visitor(self._wherecorrelator)
-        condition._process_from_dict(self._froms, False)
+        condition.accept_visitor(self.__wherecorrelator)
+        condition._process_from_dict(self.__froms, False)
         if getattr(self, attribute) is not None:
             setattr(self, attribute, and_(getattr(self, attribute), condition))
         else:
             setattr(self, attribute, condition)
 
     def clear_from(self, from_obj):
-        self._froms[from_obj] = FromClause()
+        self.__froms[from_obj] = FromClause()
         
     def append_from(self, fromclause):
         if type(fromclause) == str:
             fromclause = _TextClause(fromclause)
-        fromclause.accept_visitor(self._correlator)
-        fromclause._process_from_dict(self._froms, True)
+        fromclause.accept_visitor(self.__correlator)
+        fromclause._process_from_dict(self.__froms, True)
     def _locate_oid_column(self):
-        for f in self._froms.values():
+        for f in self.__froms.values():
             if f is self:
                 # we might be in our own _froms list if a column with us as the parent is attached,
                 # which includes textual columns. 
@@ -1553,8 +1569,8 @@ class Select(_SelectBaseMixin, FromClause):
         else:
             return None
     def _get_froms(self):
-        return [f for f in self._froms.values() if f is not self and (self._correlated is None or not self._correlated.has_key(f))]
-    froms = property(lambda s: s._get_froms())
+        return [f for f in self.__froms.values() if f is not self and (self._correlated is None or not self._correlated.has_key(f))]
+    froms = property(lambda s: s._get_froms(), doc="""a list containing all elements of the FROM clause""")
 
     def accept_visitor(self, visitor):
         # TODO: add contextual visit_ methods
@@ -1581,7 +1597,7 @@ class Select(_SelectBaseMixin, FromClause):
         
         if self._engine is not None:
             return self._engine
-        for f in self._froms.values():
+        for f in self.__froms.values():
             if f is self:
                 continue
             e = f.engine
@@ -1657,7 +1673,7 @@ class _Update(_UpdateBase):
         visitor.visit_update(self)
 
 class _Delete(_UpdateBase):
-    def __init__(self, table, whereclause, **params):
+    def __init__(self, table, whereclause):
         self.table = table
         self.whereclause = whereclause
 
index d7e9d8ce6b96bfee2169cf65b7934a9336b27531..08f2c9843694f7651b72673411460eac2463ecdb 100644 (file)
@@ -12,6 +12,7 @@ __all__ = [ 'TypeEngine', 'TypeDecorator', 'NullTypeEngine',
             ]
 
 from sqlalchemy import util, exceptions
+import inspect
 try:
     import cPickle as pickle
 except:
@@ -37,7 +38,9 @@ class AbstractType(object):
         
         this can be useful for calling setinputsizes(), for example."""
         return None
-            
+    def __repr__(self):
+        return "%s(%s)" % (self.__class__.__name__, ",".join(["%s=%s" % (k, getattr(self, k)) for k in inspect.getargspec(self.__init__)[0][1:]]))
+        
 class TypeEngine(AbstractType):
     def __init__(self, *args, **params):
         pass
index e6ce063902effe2bc3d899da9329afffe0e868c1..f87cbb46ed5074f2a8512b98f3f3aad31f275163 100644 (file)
@@ -10,6 +10,7 @@ import sqlalchemy.ext.activemapper as activemapper
 
 class testcase(testbase.PersistTest):
     def setUpAll(self):
+        sqlalchemy.clear_mappers()
         global Person, Preferences, Address
         
         class Person(ActiveMapper):
index 96ad6ec8b138a7e95a66106c9a45534368c5e83f..d88b2bf83fdd4c87c699883f81b14dbc97c0558d 100644 (file)
@@ -261,6 +261,100 @@ class QueryTest(PersistTest):
             r.close()
         finally:
             shadowed.drop()
+
+class CompoundTest(PersistTest):
+    """test compound statements like UNION, INTERSECT, particularly their ability to nest on
+    different databases."""
+    def setUpAll(self):
+        global metadata, t1, t2, t3
+        metadata = BoundMetaData(testbase.db)
+        t1 = Table('t1', metadata, 
+            Column('col1', Integer, primary_key=True),
+            Column('col2', String(30)),
+            Column('col3', String(40)),
+            Column('col4', String(30))
+            )
+        t2 = Table('t2', metadata, 
+            Column('col1', Integer, primary_key=True),
+            Column('col2', String(30)),
+            Column('col3', String(40)),
+            Column('col4', String(30)))
+        t3 = Table('t3', metadata, 
+            Column('col1', Integer, primary_key=True),
+            Column('col2', String(30)),
+            Column('col3', String(40)),
+            Column('col4', String(30)))
+        metadata.create_all()
+        
+        t1.insert().execute([
+            dict(col2="t1col2r1", col3="aaa", col4="aaa"),
+            dict(col2="t1col2r2", col3="bbb", col4="bbb"),
+            dict(col2="t1col2r3", col3="ccc", col4="ccc"),
+        ])
+        t2.insert().execute([
+            dict(col2="t2col2r1", col3="aaa", col4="bbb"),
+            dict(col2="t2col2r2", col3="bbb", col4="ccc"),
+            dict(col2="t2col2r3", col3="ccc", col4="aaa"),
+        ])
+        t3.insert().execute([
+            dict(col2="t3col2r1", col3="aaa", col4="ccc"),
+            dict(col2="t3col2r2", col3="bbb", col4="aaa"),
+            dict(col2="t3col2r3", col3="ccc", col4="bbb"),
+        ])
+        
+    def tearDownAll(self):
+        metadata.drop_all()
+        
+    def test_union(self):
+        (s1, s2) = (
+                    select([t1.c.col3, t1.c.col4], t1.c.col2.in_("t1col2r1", "t1col2r2")),
+            select([t2.c.col3, t2.c.col4], t2.c.col2.in_("t2col2r2", "t2col2r3"))
+        )        
+        u = union(s1, s2)
+        assert u.execute().fetchall() == [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
+        assert u.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
+        
+    @testbase.unsupported('mysql')
+    def test_intersect(self):
+        i = intersect(
+            select([t2.c.col3, t2.c.col4]),
+            select([t2.c.col3, t2.c.col4], t2.c.col4==t3.c.col3)
+        )
+        assert i.execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
+        assert i.alias('bar').select().execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
+
+    @testbase.unsupported('mysql')
+    def test_except_style1(self):
+        e = except_(union(
+            select([t1.c.col3, t1.c.col4]),
+            select([t2.c.col3, t2.c.col4]),
+            select([t3.c.col3, t3.c.col4]),
+        parens=True), select([t2.c.col3, t2.c.col4]))
+        assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
+
+    @testbase.unsupported('mysql')
+    def test_except_style2(self):
+        e = except_(union(
+            select([t1.c.col3, t1.c.col4]),
+            select([t2.c.col3, t2.c.col4]),
+            select([t3.c.col3, t3.c.col4]),
+        ).alias('foo').select(), select([t2.c.col3, t2.c.col4]))
+        assert e.execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
+        assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
+
+    @testbase.unsupported('mysql')
+    def test_composite(self):
+        u = intersect(
+            select([t2.c.col3, t2.c.col4]),
+            union(
+                select([t1.c.col3, t1.c.col4]),
+                select([t2.c.col3, t2.c.col4]),
+                select([t3.c.col3, t3.c.col4]),
+            ).alias('foo').select()
+        )
+        assert u.execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
+        assert u.alias('foo').select().execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
+
         
 if __name__ == "__main__":
     testbase.main()