From 19fcb943c431c61024ff7548bfff96f0f4d8c67a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 25 Nov 2006 21:32:26 +0000 Subject: [PATCH] - cleanup on some instance vars in Select (is_scalar, is_subquery, _froms is __froms, removed unused 'nowait', '_text', etc) - cleaned up __repr__ on Column, AbstractTypeEngine - added standalone intersect(_all), except(_all) functions, unit tests illustrating nesting patterns [ticket:247] --- CHANGES | 4 ++ lib/sqlalchemy/ansisql.py | 4 +- lib/sqlalchemy/ext/sqlsoup.py | 4 +- lib/sqlalchemy/schema.py | 18 +++++-- lib/sqlalchemy/sql.py | 88 ++++++++++++++++++-------------- lib/sqlalchemy/types.py | 5 +- test/ext/activemapper.py | 1 + test/sql/query.py | 94 +++++++++++++++++++++++++++++++++++ 8 files changed, 173 insertions(+), 45 deletions(-) diff --git a/CHANGES b/CHANGES index 98404cdd8d..b7504b9670 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,7 @@ +0.3.2 +- added keywords for EXCEPT, INTERSECT, EXCEPT ALL, INTERSECT ALL +[ticket:247] + 0.3.1 - Engine/Pool: - some new Pool utility classes, updated docs diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index a0ce64905d..2e0fe6e347 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -302,7 +302,7 @@ class ANSICompiler(sql.Compiled): self.select_stack.append(select) for c in select._raw_columns: # TODO: make this polymorphic? - if isinstance(c, sql.Select) and c._scalar: + if isinstance(c, sql.Select) and c.is_scalar: c.accept_visitor(self) inner_columns[self.get_str(c)] = c continue @@ -319,7 +319,7 @@ class ANSICompiler(sql.Compiled): inner_columns[co._label] = l # TODO: figure this out, a ColumnClause with a select as a parent # is different from any other kind of parent - elif select.issubquery and isinstance(co, sql._ColumnClause) and co.table is not None and not isinstance(co.table, sql.Select): + elif select.is_subquery and isinstance(co, sql._ColumnClause) and co.table is not None and not isinstance(co.table, sql.Select): # SQLite doesnt like selecting from a subquery where the column # names look like table.colname, so add a label synonomous with # the column name diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py index d83ecfb590..d3081bc237 100644 --- a/lib/sqlalchemy/ext/sqlsoup.py +++ b/lib/sqlalchemy/ext/sqlsoup.py @@ -324,9 +324,7 @@ def _selectable_name(selectable): if isinstance(selectable, sql.Alias): return _selectable_name(selectable.selectable) elif isinstance(selectable, sql.Select): - # sometimes a Select has itself in _froms - nonrecursive_froms = [s for s in selectable._froms if s is not selectable] - return ''.join([_selectable_name(s) for s in nonrecursive_froms]) + return ''.join([_selectable_name(s) for s in selectable.froms]) elif isinstance(selectable, schema.Table): return selectable.name.capitalize() else: diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index fb6894f0bc..d9a7684e72 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -241,7 +241,7 @@ class Table(SchemaItem, sql.TableClause): [repr(self.name)] + [repr(self.metadata)] + [repr(x) for x in self.columns] + ["%s=%s" % (k, repr(getattr(self, k))) for k in ['schema']] - , ',\n') + , ',') def __str__(self): return _get_table_key(self.name, self.schema) @@ -401,10 +401,22 @@ class Column(SchemaItem, sql._ColumnClause): fk._set_parent(self) def __repr__(self): - return "Column(%s)" % string.join( + kwarg = [] + if self.key != self.name: + kwarg.append('key') + if self._primary_key: + kwarg.append('primary_key') + if not self.nullable: + kwarg.append('nullable') + if self.onupdate: + kwarg.append('onupdate') + if self.default: + kwarg.append('default') + return "Column(%s)" % string.join( [repr(self.name)] + [repr(self.type)] + [repr(x) for x in self.foreign_keys if x is not None] + - ["%s=%s" % (k, repr(getattr(self, k))) for k in ['key', 'primary_key', 'nullable', 'default', 'onupdate']] + [repr(x) for x in self.constraints] + + ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg] , ',') def _get_parent(self): diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index ce33810a52..b5faf37fef 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -9,7 +9,8 @@ from sqlalchemy import util, exceptions from sqlalchemy import types as sqltypes import string, re, random, sets -__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'case', 'cast', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists', 'extract','AbstractDialect', 'ClauseParameters', 'ClauseVisitor', 'Executor', 'Compiled', 'ClauseElement', 'ColumnElement', 'ColumnCollection', 'FromClause', 'TableClause', 'Select', 'Alias', 'CompoundSelect','Join', 'Selectable'] + +__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'case', 'cast', 'union', 'union_all', 'except_', 'except_all', 'intersect', 'intersect_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists', 'extract','AbstractDialect', 'ClauseParameters', 'ClauseVisitor', 'Executor', 'Compiled', 'ClauseElement', 'ColumnElement', 'ColumnCollection', 'FromClause', 'TableClause', 'Select', 'Alias', 'CompoundSelect','Join', 'Selectable'] def desc(column): """return a descending ORDER BY clause element, e.g.: @@ -181,6 +182,18 @@ def union(*selects, **params): def union_all(*selects, **params): return _compound_select('UNION ALL', *selects, **params) +def except_(*selects, **params): + return _compound_select('EXCEPT', *selects, **params) + +def except_all(*selects, **params): + return _compound_select('EXCEPT ALL', *selects, **params) + +def intersect(*selects, **params): + return _compound_select('INTERSECT', *selects, **params) + +def intersect_all(*selects, **params): + return _compound_select('INTERSECT ALL', *selects, **params) + def alias(*args, **params): return Alias(*args, **params) @@ -1357,7 +1370,7 @@ class _SelectBaseMixin(object): def select(self, whereclauses = None, **params): return select([self], whereclauses, **params) def _get_from_objects(self): - if self.is_where or self._scalar: + if self.is_where or self.is_scalar: return [] else: return [self] @@ -1366,19 +1379,27 @@ class CompoundSelect(_SelectBaseMixin, FromClause): def __init__(self, keyword, *selects, **kwargs): _SelectBaseMixin.__init__(self) self.keyword = keyword - self.selects = selects self.use_labels = kwargs.pop('use_labels', False) self.parens = kwargs.pop('parens', False) self.correlate = kwargs.pop('correlate', False) self.for_update = kwargs.pop('for_update', False) self.nowait = kwargs.pop('nowait', False) - self.limit = kwargs.get('limit', None) - self.offset = kwargs.get('offset', None) - for s in self.selects: + self.limit = kwargs.pop('limit', None) + self.offset = kwargs.pop('offset', None) + self.is_compound = True + self.is_where = False + self.is_scalar = False + + self.selects = selects + + for s in selects: s.group_by(None) s.order_by(None) - self.group_by(*kwargs.get('group_by', [None])) - self.order_by(*kwargs.get('order_by', [None])) + + self.group_by(*kwargs.pop('group_by', [None])) + self.order_by(*kwargs.pop('order_by', [None])) + if len(kwargs): + raise TypeError("invalid keyword argument(s) for CompoundSelect: %s" % repr(kwargs.keys())) self._col_map = {} name = property(lambda s:s.keyword + " statement") @@ -1420,9 +1441,9 @@ class CompoundSelect(_SelectBaseMixin, FromClause): class Select(_SelectBaseMixin, FromClause): """represents a SELECT statement, with appendable clauses, as well as the ability to execute itself and return a result set.""" - def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, for_update=False, nowait=False, engine=None, limit=None, offset=None, scalar=False, correlate=True): + def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, for_update=False, engine=None, limit=None, offset=None, scalar=False, correlate=True): _SelectBaseMixin.__init__(self) - self._froms = util.OrderedDict() + self.__froms = util.OrderedDict() self.use_labels = use_labels self.whereclause = None self.having = None @@ -1430,31 +1451,29 @@ class Select(_SelectBaseMixin, FromClause): self.limit = limit self.offset = offset self.for_update = for_update - self.nowait = nowait + self.is_compound = False # indicates that this select statement should not expand its columns # into the column clause of an enclosing select, and should instead # act like a single scalar column - self._scalar = scalar + self.is_scalar = scalar # indicates if this select statement, as a subquery, should correlate # its FROM clause to that of an enclosing select statement self.correlate = correlate # indicates if this select statement is a subquery inside another query - self.issubquery = False + self.is_subquery = False # indicates if this select statement is a subquery as a criterion # inside of a WHERE clause self.is_where = False self.distinct = distinct - self._text = None self._raw_columns = [] self._correlated = None - self._correlator = Select._CorrelatedVisitor(self, False) - self._wherecorrelator = Select._CorrelatedVisitor(self, True) - + self.__correlator = Select._CorrelatedVisitor(self, False) + self.__wherecorrelator = Select._CorrelatedVisitor(self, True) self.group_by(*(group_by or [None])) self.order_by(*(order_by or [None])) @@ -1471,10 +1490,6 @@ class Select(_SelectBaseMixin, FromClause): for f in from_obj: self.append_from(f) - def _foo(self): - raise "this is a temporary assertion while we refactor SQL to not call 'name' on non-table Selectables" - name = property(lambda s:s._foo()) #"SELECT statement") - class _CorrelatedVisitor(ClauseVisitor): """visits a clause, locates any Select clauses, and tells them that they should correlate their FROM list to that of their parent.""" @@ -1491,12 +1506,12 @@ class Select(_SelectBaseMixin, FromClause): if select is self.select: return select.is_where = self.is_where - select.issubquery = True + select.is_subquery = True select.parens = True if not select.correlate: return if getattr(select, '_correlated', None) is None: - select._correlated = self.select._froms + select._correlated = self.select._Select__froms def append_column(self, column): if _is_literal(column): @@ -1506,12 +1521,13 @@ class Select(_SelectBaseMixin, FromClause): # if the column is a Select statement itself, # accept visitor - column.accept_visitor(self._correlator) + column.accept_visitor(self.__correlator) # visit the FROM objects of the column looking for more Selects for f in column._get_from_objects(): - f.accept_visitor(self._correlator) - column._process_from_dict(self._froms, False) + f.accept_visitor(self.__correlator) + column._process_from_dict(self.__froms, False) + def _exportable_columns(self): return self._raw_columns def _proxy_column(self, column): @@ -1526,23 +1542,23 @@ class Select(_SelectBaseMixin, FromClause): def _append_condition(self, attribute, condition): if type(condition) == str: condition = _TextClause(condition) - condition.accept_visitor(self._wherecorrelator) - condition._process_from_dict(self._froms, False) + condition.accept_visitor(self.__wherecorrelator) + condition._process_from_dict(self.__froms, False) if getattr(self, attribute) is not None: setattr(self, attribute, and_(getattr(self, attribute), condition)) else: setattr(self, attribute, condition) def clear_from(self, from_obj): - self._froms[from_obj] = FromClause() + self.__froms[from_obj] = FromClause() def append_from(self, fromclause): if type(fromclause) == str: fromclause = _TextClause(fromclause) - fromclause.accept_visitor(self._correlator) - fromclause._process_from_dict(self._froms, True) + fromclause.accept_visitor(self.__correlator) + fromclause._process_from_dict(self.__froms, True) def _locate_oid_column(self): - for f in self._froms.values(): + for f in self.__froms.values(): if f is self: # we might be in our own _froms list if a column with us as the parent is attached, # which includes textual columns. @@ -1553,8 +1569,8 @@ class Select(_SelectBaseMixin, FromClause): else: return None def _get_froms(self): - return [f for f in self._froms.values() if f is not self and (self._correlated is None or not self._correlated.has_key(f))] - froms = property(lambda s: s._get_froms()) + return [f for f in self.__froms.values() if f is not self and (self._correlated is None or not self._correlated.has_key(f))] + froms = property(lambda s: s._get_froms(), doc="""a list containing all elements of the FROM clause""") def accept_visitor(self, visitor): # TODO: add contextual visit_ methods @@ -1581,7 +1597,7 @@ class Select(_SelectBaseMixin, FromClause): if self._engine is not None: return self._engine - for f in self._froms.values(): + for f in self.__froms.values(): if f is self: continue e = f.engine @@ -1657,7 +1673,7 @@ class _Update(_UpdateBase): visitor.visit_update(self) class _Delete(_UpdateBase): - def __init__(self, table, whereclause, **params): + def __init__(self, table, whereclause): self.table = table self.whereclause = whereclause diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index d7e9d8ce6b..08f2c98436 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -12,6 +12,7 @@ __all__ = [ 'TypeEngine', 'TypeDecorator', 'NullTypeEngine', ] from sqlalchemy import util, exceptions +import inspect try: import cPickle as pickle except: @@ -37,7 +38,9 @@ class AbstractType(object): this can be useful for calling setinputsizes(), for example.""" return None - + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, ",".join(["%s=%s" % (k, getattr(self, k)) for k in inspect.getargspec(self.__init__)[0][1:]])) + class TypeEngine(AbstractType): def __init__(self, *args, **params): pass diff --git a/test/ext/activemapper.py b/test/ext/activemapper.py index e6ce063902..f87cbb46ed 100644 --- a/test/ext/activemapper.py +++ b/test/ext/activemapper.py @@ -10,6 +10,7 @@ import sqlalchemy.ext.activemapper as activemapper class testcase(testbase.PersistTest): def setUpAll(self): + sqlalchemy.clear_mappers() global Person, Preferences, Address class Person(ActiveMapper): diff --git a/test/sql/query.py b/test/sql/query.py index 96ad6ec8b1..d88b2bf83f 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -261,6 +261,100 @@ class QueryTest(PersistTest): r.close() finally: shadowed.drop() + +class CompoundTest(PersistTest): + """test compound statements like UNION, INTERSECT, particularly their ability to nest on + different databases.""" + def setUpAll(self): + global metadata, t1, t2, t3 + metadata = BoundMetaData(testbase.db) + t1 = Table('t1', metadata, + Column('col1', Integer, primary_key=True), + Column('col2', String(30)), + Column('col3', String(40)), + Column('col4', String(30)) + ) + t2 = Table('t2', metadata, + Column('col1', Integer, primary_key=True), + Column('col2', String(30)), + Column('col3', String(40)), + Column('col4', String(30))) + t3 = Table('t3', metadata, + Column('col1', Integer, primary_key=True), + Column('col2', String(30)), + Column('col3', String(40)), + Column('col4', String(30))) + metadata.create_all() + + t1.insert().execute([ + dict(col2="t1col2r1", col3="aaa", col4="aaa"), + dict(col2="t1col2r2", col3="bbb", col4="bbb"), + dict(col2="t1col2r3", col3="ccc", col4="ccc"), + ]) + t2.insert().execute([ + dict(col2="t2col2r1", col3="aaa", col4="bbb"), + dict(col2="t2col2r2", col3="bbb", col4="ccc"), + dict(col2="t2col2r3", col3="ccc", col4="aaa"), + ]) + t3.insert().execute([ + dict(col2="t3col2r1", col3="aaa", col4="ccc"), + dict(col2="t3col2r2", col3="bbb", col4="aaa"), + dict(col2="t3col2r3", col3="ccc", col4="bbb"), + ]) + + def tearDownAll(self): + metadata.drop_all() + + def test_union(self): + (s1, s2) = ( + select([t1.c.col3, t1.c.col4], t1.c.col2.in_("t1col2r1", "t1col2r2")), + select([t2.c.col3, t2.c.col4], t2.c.col2.in_("t2col2r2", "t2col2r3")) + ) + u = union(s1, s2) + assert u.execute().fetchall() == [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] + assert u.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] + + @testbase.unsupported('mysql') + def test_intersect(self): + i = intersect( + select([t2.c.col3, t2.c.col4]), + select([t2.c.col3, t2.c.col4], t2.c.col4==t3.c.col3) + ) + assert i.execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] + assert i.alias('bar').select().execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] + + @testbase.unsupported('mysql') + def test_except_style1(self): + e = except_(union( + select([t1.c.col3, t1.c.col4]), + select([t2.c.col3, t2.c.col4]), + select([t3.c.col3, t3.c.col4]), + parens=True), select([t2.c.col3, t2.c.col4])) + assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')] + + @testbase.unsupported('mysql') + def test_except_style2(self): + e = except_(union( + select([t1.c.col3, t1.c.col4]), + select([t2.c.col3, t2.c.col4]), + select([t3.c.col3, t3.c.col4]), + ).alias('foo').select(), select([t2.c.col3, t2.c.col4])) + assert e.execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')] + assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')] + + @testbase.unsupported('mysql') + def test_composite(self): + u = intersect( + select([t2.c.col3, t2.c.col4]), + union( + select([t1.c.col3, t1.c.col4]), + select([t2.c.col3, t2.c.col4]), + select([t3.c.col3, t3.c.col4]), + ).alias('foo').select() + ) + assert u.execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] + assert u.alias('foo').select().execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] + if __name__ == "__main__": testbase.main() -- 2.47.2