From: Mike Bayer Date: Sat, 19 Jan 2008 18:36:52 +0000 (+0000) Subject: - some expression fixup: X-Git-Tag: rel_0_4_3~98 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=840a2fabb8999b4b3807dfa55d771627656ab1db;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - some expression fixup: - the '.c.' attribute on a selectable now gets an entry for every column expression in its columns clause; previously, "unnamed" columns like functions and CASE statements weren't getting put there. Now they will, using their full string representation if no 'name' is available. - The anonymous 'label' generated for otherwise unlabeled functions and expressions now propagates outwards at compile time for expressions like select([select([func.foo()])]) - a CompositeSelect, i.e. any union(), union_all(), intersect(), etc. now asserts that each selectable contains the same number of columns. This conforms to the corresponding SQL requirement. - building on the above ideas, CompositeSelects now build up their ".c." collection based on the names present in the first selectable only; corresponding_column() now works fully for all embedded selectables. --- diff --git a/CHANGES b/CHANGES index 246853b3d9..7087bcc84a 100644 --- a/CHANGES +++ b/CHANGES @@ -8,7 +8,28 @@ CHANGES - added "ilike()" operator to column operations. compiles to ILIKE on postgres, lower(x) LIKE lower(y) on all others [ticket:727] - + + - some expression fixup: + - the '.c.' attribute on a selectable now gets an + entry for every column expression in its columns + clause; previously, "unnamed" columns like functions + and CASE statements weren't getting put there. Now + they will, using their full string representation + if no 'name' is available. + - The anonymous 'label' generated for otherwise + unlabeled functions and expressions now propagates + outwards at compile time for expressions like + select([select([func.foo()])]) + - a CompositeSelect, i.e. any union(), union_all(), + intersect(), etc. now asserts that each selectable + contains the same number of columns. This conforms + to the corresponding SQL requirement. + - building on the above ideas, CompositeSelects + now build up their ".c." collection based on + the names present in the first selectable only; + corresponding_column() now works fully for all + embedded selectables. + - orm - proper error message is raised when trying to access expired instance attributes with no session present diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8f2e3372a3..666a38d397 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -464,7 +464,7 @@ class DefaultCompiler(engine.Compiled): not isinstance(column.table, sql.Select): return column.label(column.name) elif not isinstance(column, (sql._UnaryExpression, sql._TextClause)) and (not hasattr(column, 'name') or isinstance(column, sql._Function)): - return column.label(None) + return column.anon_label else: return column @@ -728,7 +728,7 @@ class DefaultCompiler(engine.Compiled): return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) def __str__(self): - return self.string + return self.string or '' class DDLBase(engine.SchemaIterator): def find_alterables(self, tables): diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 3ebc4960fa..c603418028 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1402,17 +1402,37 @@ class ColumnElement(ClauseElement, _CompareMixin): ``ColumnElement`` as it appears in the select list of a descending selectable. - The default implementation returns a ``_ColumnClause`` if a - name is given, else just returns self. """ if name is not None: - co = _ColumnClause(name, selectable) + co = _ColumnClause(name, selectable, type_=getattr(self, 'type', None)) co.proxies = [self] selectable.columns[name]= co return co else: - return self + name = str(self) + co = _ColumnClause(self.anon_label.name, selectable, type_=getattr(self, 'type', None)) + co.proxies = [self] + selectable.columns[name] = co + return co + + def anon_label(self): + """provides a constant 'anonymous label' for this ColumnElement. + + This is a label() expression which will be named at compile time. + The same label() is returned each time anon_label is called so + that expressions can reference anon_label multiple times, producing + the same label name at compile time. + + the compiler uses this function automatically at compile time + for expressions that are known to be 'unnamed' like binary + expressions and function calls. + """ + + if not hasattr(self, '_ColumnElement__anon_label'): + self.__anon_label = self.label(None) + return self.__anon_label + anon_label = property(anon_label) class ColumnCollection(util.OrderedProperties): """An ordered dictionary that stores a list of ColumnElement @@ -2026,15 +2046,6 @@ class _Cast(ColumnElement): def _get_from_objects(self, **modifiers): return self.clause._get_from_objects(**modifiers) - def _make_proxy(self, selectable, name=None): - if name is not None: - co = _ColumnClause(name, selectable, type_=self.type) - co.proxies = [self] - selectable.columns[name]= co - return co - else: - return self - class _UnaryExpression(ColumnElement): def __init__(self, element, operator=None, modifier=None, type_=None, negate=None): @@ -2864,8 +2875,16 @@ class CompoundSelect(_SelectBaseMixin, FromClause): self.keyword = keyword self.selects = [] + numcols = None + # some DBs do not like ORDER BY in the inner queries of a UNION, etc. for n, s in enumerate(selects): + if not numcols: + numcols = len(s.c) + elif len(s.c) != numcols: + raise exceptions.ArgumentError("All selectables passed to CompoundSelect must have identical numbers of columns; select #%d has %d columns, select #%d has %d" % + (1, len(self.selects[0].c), n+1, len(s.c)) + ) if s._order_by_clause: s = s.order_by(None) # unions group from left to right, so don't group first select @@ -2892,17 +2911,22 @@ class CompoundSelect(_SelectBaseMixin, FromClause): yield c def _proxy_column(self, column): - existing = self._col_map.get(column.name, None) - if existing is not None: - existing.proxies.append(column) - return existing - else: + selectable = column.table + col_ordering = self._col_map.get(selectable, None) + if col_ordering is None: + self._col_map[selectable] = col_ordering = [] + + if selectable is self.selects[0]: if self.use_labels: col = column._make_proxy(self, name=column._label) else: col = column._make_proxy(self) - self._col_map[col.name] = col - return col + col_ordering.append(col) + else: + col_ordering.append(column) + existing = self._col_map[self.selects[0]][len(col_ordering) - 1] + existing.proxies.append(column) + return existing def _copy_internals(self, clone=_clone): self._clone_from_clause() diff --git a/test/sql/select.py b/test/sql/select.py index 07c3ce69e6..c34cec7c51 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -845,95 +845,94 @@ mytable.description FROM myothertable JOIN mytable ON mytable.myid = myothertabl ) def testunion(self): - x = union( - select([table1], table1.c.myid == 5), - select([table1], table1.c.myid == 12), - order_by = [table1.c.myid], - ) + try: + union(table3.select(), table1.select()) + except exceptions.ArgumentError, err: + assert str(err) == "All selectables passed to CompoundSelect must have identical numbers of columns; select #1 has 2 columns, select #2 has 3" + + x = union( + select([table1], table1.c.myid == 5), + select([table1], table1.c.myid == 12), + order_by = [table1.c.myid], + ) - self.assert_compile(x, "SELECT mytable.myid, mytable.name, mytable.description \ + self.assert_compile(x, "SELECT mytable.myid, mytable.name, mytable.description \ FROM mytable WHERE mytable.myid = :mytable_myid_1 UNION \ SELECT mytable.myid, mytable.name, mytable.description \ FROM mytable WHERE mytable.myid = :mytable_myid_2 ORDER BY mytable.myid") - self.assert_compile( - union( - select([table1]), - select([table2]), - select([table3]) - ) - , - "SELECT mytable.myid, mytable.name, mytable.description \ + u1 = union( + select([table1.c.myid, table1.c.name]), + select([table2]), + select([table3]) + ) + self.assert_compile(u1, + "SELECT mytable.myid, mytable.name \ FROM mytable UNION SELECT myothertable.otherid, myothertable.othername \ FROM myothertable UNION SELECT thirdtable.userid, thirdtable.otherstuff FROM thirdtable") - u = union( - select([table1]), + assert u1.corresponding_column(table2.c.otherid) is u1.c.myid + + # TODO - why is there an extra space before the LIMIT ? + self.assert_compile( + union( + select([table1.c.myid, table1.c.name]), select([table2]), - select([table3]) + order_by=['myid'], + offset=10, + limit=5 ) - assert u.corresponding_column(table2.c.otherid) is u.c.otherid - - self.assert_compile( - union( - select([table1]), - select([table2]), - order_by=['myid'], - offset=10, - limit=5 - ) - , "SELECT mytable.myid, mytable.name, mytable.description \ + , "SELECT mytable.myid, mytable.name \ FROM mytable UNION SELECT myothertable.otherid, myothertable.othername \ -FROM myothertable ORDER BY myid \ - LIMIT 5 OFFSET 10" - ) +FROM myothertable ORDER BY myid LIMIT 5 OFFSET 10" + ) - self.assert_compile( - union( - select([table1.c.myid, table1.c.name, func.max(table1.c.description)], table1.c.name=='name2', group_by=[table1.c.myid, table1.c.name]), - table1.select(table1.c.name=='name1') - ) - , - "SELECT mytable.myid, mytable.name, max(mytable.description) AS max_1 FROM mytable \ + self.assert_compile( + union( + select([table1.c.myid, table1.c.name, func.max(table1.c.description)], table1.c.name=='name2', group_by=[table1.c.myid, table1.c.name]), + table1.select(table1.c.name=='name1') + ) + , + "SELECT mytable.myid, mytable.name, max(mytable.description) AS max_1 FROM mytable \ WHERE mytable.name = :mytable_name_1 GROUP BY mytable.myid, mytable.name UNION SELECT mytable.myid, mytable.name, mytable.description \ FROM mytable WHERE mytable.name = :mytable_name_2" - ) + ) - self.assert_compile( - union( - select([literal(100).label('value')]), - select([literal(200).label('value')]) - ), - "SELECT :param_1 AS value UNION SELECT :param_2 AS value" - ) + self.assert_compile( + union( + select([literal(100).label('value')]), + select([literal(200).label('value')]) + ), + "SELECT :param_1 AS value UNION SELECT :param_2 AS value" + ) def test_compound_select_grouping(self): - self.assert_compile( - union_all( - select([table1.c.myid]), - union( - select([table2.c.otherid]), - select([table3.c.userid]), - ) + self.assert_compile( + union_all( + select([table1.c.myid]), + union( + select([table2.c.otherid]), + select([table3.c.userid]), ) - , - "SELECT mytable.myid FROM mytable UNION ALL (SELECT myothertable.otherid FROM myothertable UNION \ + ) + , + "SELECT mytable.myid FROM mytable UNION ALL (SELECT myothertable.otherid FROM myothertable UNION \ SELECT thirdtable.userid FROM thirdtable)" + ) + # This doesn't need grouping, so don't group to not give sqlite unnecessarily hard time + self.assert_compile( + union( + except_( + select([table2.c.otherid]), + select([table3.c.userid]), + ), + select([table1.c.myid]) ) - # This doesn't need grouping, so don't group to not give sqlite unnecessarily hard time - self.assert_compile( - union( - except_( - select([table2.c.otherid]), - select([table3.c.userid]), - ), - select([table1.c.myid]) - ) - , - "SELECT myothertable.otherid FROM myothertable EXCEPT SELECT thirdtable.userid FROM thirdtable \ + , + "SELECT myothertable.otherid FROM myothertable EXCEPT SELECT thirdtable.userid FROM thirdtable \ UNION SELECT mytable.myid FROM mytable" - ) + ) def testouterjoin(self): query = select( @@ -1253,7 +1252,40 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE "SELECT op.field FROM op WHERE (op.field = op.field) BETWEEN :param_1 AND :param_2") self.assert_compile(table.select(between((table.c.field == table.c.field), False, True)), "SELECT op.field FROM op WHERE (op.field = op.field) BETWEEN :param_1 AND :param_2") - + + def test_naming(self): + s1 = select([table1.c.myid, table1.c.myid.label('foobar'), func.hoho(table1.c.name), func.lala(table1.c.name).label('gg')]) + assert s1.c.keys() == ['myid', 'foobar', 'hoho(mytable.name)', 'gg'] + + from sqlalchemy.databases.sqlite import SLNumeric + meta = MetaData() + t1 = Table('mytable', meta, Column('col1', Integer)) + + for col, key, expr, label in ( + (table1.c.name, 'name', 'mytable.name', None), + (table1.c.myid==12, 'mytable.myid = :mytable_myid_1', 'mytable.myid = :mytable_myid_1', 'anon_1'), + (func.hoho(table1.c.myid), 'hoho(mytable.myid)', 'hoho(mytable.myid)', 'hoho_1'), + (cast(table1.c.name, SLNumeric), 'CAST(mytable.name AS NUMERIC(10, 2))', 'CAST(mytable.name AS NUMERIC(10, 2))', 'anon_1'), + (t1.c.col1, 'col1', 'mytable.col1', None), + (column('some wacky thing'), 'some wacky thing', '"some wacky thing"', '') + ): + s1 = select([col], from_obj=getattr(col, 'table', None) or table1) + assert s1.c.keys() == [key], s1.c.keys() + + if label: + self.assert_compile(s1, "SELECT %s AS %s FROM mytable" % (expr, label)) + else: + self.assert_compile(s1, "SELECT %s FROM mytable" % (expr,)) + + s1 = select([s1]) + if label: + self.assert_compile(s1, "SELECT %s FROM (SELECT %s AS %s FROM mytable)" % (label, expr, label)) + elif col.table is not None: + # sqlite rule labels subquery columns + self.assert_compile(s1, "SELECT %s FROM (SELECT %s AS %s FROM mytable)" % (key,expr, key)) + else: + self.assert_compile(s1, "SELECT %s FROM (SELECT %s FROM mytable)" % (expr,expr)) + class CRUDTest(SQLCompileTest): def testinsert(self): # generic insert, will create bind params for all columns diff --git a/test/sql/selectable.py b/test/sql/selectable.py index 45bd7d823a..8a25db1842 100755 --- a/test/sql/selectable.py +++ b/test/sql/selectable.py @@ -74,8 +74,6 @@ class SelectableTest(AssertMixin): j = join(a, table2) criterion = a.c.col1 == table2.c.col2 - print - print str(j) self.assert_(criterion.compare(j.onclause)) def testunion(self): @@ -213,7 +211,7 @@ class SelectableTest(AssertMixin): assert u.corresponding_column(table2.oid_column) is u.oid_column assert u.corresponding_column(s.oid_column) is u.oid_column assert u.corresponding_column(s2.oid_column) is u.oid_column - + class PrimaryKeyTest(AssertMixin): def test_join_pk_collapse_implicit(self): """test that redundant columns in a join get 'collapsed' into a minimal primary key,