From: Mike Bayer Date: Sat, 10 Nov 2007 03:02:16 +0000 (+0000) Subject: - anonymous column expressions are automatically labeled. X-Git-Tag: rel_0_4_1~34 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=ea46e556f9f691735bc14885648a92e8cf7177d5;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - anonymous column expressions are automatically labeled. e.g. select([x* 5]) produces "SELECT x * 5 AS anon_1". This allows the labelname to be present in the cursor.description which can then be appropriately matched to result-column processing rules. (we can't reliably use positional tracking for result-column matches since text() expressions may represent multiple columns). - operator overloading is now controlled by TypeEngine objects - the one built-in operator overload so far is String types overloading '+' to be the string concatenation operator. User-defined types can also define their own operator overloading by overriding the adapt_operator(self, op) method. - untyped bind parameters on the right side of a binary expression will be assigned the type of the left side of the operation, to better enable the appropriate bind parameter processing to take effect [ticket:819] --- diff --git a/CHANGES b/CHANGES index 60164c23d8..282079b8b2 100644 --- a/CHANGES +++ b/CHANGES @@ -12,6 +12,24 @@ CHANGES - Added contains operator (generates a "LIKE %%" clause). + - anonymous column expressions are automatically labeled. + e.g. select([x* 5]) produces "SELECT x * 5 AS anon_1". + This allows the labelname to be present in the cursor.description + which can then be appropriately matched to result-column processing + rules. (we can't reliably use positional tracking for result-column + matches since text() expressions may represent multiple columns). + + - operator overloading is now controlled by TypeEngine objects - the + one built-in operator overload so far is String types overloading + '+' to be the string concatenation operator. + User-defined types can also define their own operator overloading + by overriding the adapt_operator(self, op) method. + + - untyped bind parameters on the right side of a binary expression + will be assigned the type of the left side of the operation, to better + enable the appropriate bind parameter processing to take effect + [ticket:819] + - Removed regular expression step from most statement compilations. Also fixes [ticket:833] diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py index 741641afcf..247ab2d419 100644 --- a/lib/sqlalchemy/databases/informix.py +++ b/lib/sqlalchemy/databases/informix.py @@ -410,6 +410,7 @@ class InfoCompiler(compiler.DefaultCompiler): return "" def __visit_label(self, label): + # TODO: whats this method for ? if self.select_stack: self.typemap.setdefault(label.name.lower(), label.obj.type) if self.strings[label.obj]: diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 38e8e1217d..92b454f820 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -895,13 +895,13 @@ class MSSQLCompiler(compiler.DefaultCompiler): kwargs['mssql_aliased'] = True return super(MSSQLCompiler, self).visit_alias(alias, **kwargs) - def visit_column(self, column): + def visit_column(self, column, **kwargs): if column.table is not None and not self.isupdate and not self.isdelete: # translate for schema-qualified table aliases t = self._schema_aliased_table(column.table) if t is not None: return self.process(t.corresponding_column(column)) - return super(MSSQLCompiler, self).visit_column(column) + return super(MSSQLCompiler, self).visit_column(column, **kwargs) def visit_binary(self, binary): """Move bind parameters to the right-hand side of an operator, where possible.""" diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 00b297f973..88ac0e2026 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -109,7 +109,7 @@ class PGBoolean(sqltypes.Boolean): def get_col_spec(self): return "BOOLEAN" -class PGArray(sqltypes.TypeEngine, sqltypes.Concatenable): +class PGArray(sqltypes.Concatenable, sqltypes.TypeEngine): def __init__(self, item_type): if isinstance(item_type, type): item_type = item_type() diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 013c5704b9..859fb796e8 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1339,6 +1339,7 @@ class ResultProxy(object): for i, item in enumerate(metadata): # sqlite possibly prepending table name to colnames so strip colname = (item[0].split('.')[-1]).decode(self.dialect.encoding) + if self.context.typemap is not None: type = self.context.typemap.get(colname.lower(), typemap.get(item[1], types.NULLTYPE)) else: diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 9c82cd4aa6..1401170693 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -213,17 +213,19 @@ class DefaultCompiler(engine.Compiled): def visit_grouping(self, grouping, **kwargs): return "(" + self.process(grouping.elem) + ")" - def visit_label(self, label): + def visit_label(self, label, typemap=None, column_labels=None): labelname = self._truncated_identifier("colident", label.name) - if len(self.stack) == 1 and self.stack[-1].get('select'): + if typemap is not None: self.typemap.setdefault(labelname.lower(), label.obj.type) + + if column_labels is not None: if isinstance(label.obj, sql._ColumnClause): - self.column_labels[label.obj._label] = labelname - self.column_labels[label.name] = labelname + column_labels[label.obj._label] = labelname + column_labels[label.name] = labelname return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)]) - def visit_column(self, column, **kwargs): + def visit_column(self, column, typemap=None, column_labels=None, **kwargs): # there is actually somewhat of a ruleset when you would *not* necessarily # want to truncate a column identifier, if its mapped to the name of a # physical column. but thats very hard to identify at this point, and @@ -234,10 +236,9 @@ class DefaultCompiler(engine.Compiled): else: name = column.name - if len(self.stack) == 1 and self.stack[-1].get('select'): - # if we are within a visit to a Select, set up the "typemap" - # for this column which is used to translate result set values - self.typemap.setdefault(name.lower(), column.type) + if typemap is not None: + typemap.setdefault(name.lower(), column.type) + if column_labels is not None: self.column_labels.setdefault(column._label, name.lower()) if column._is_oid: @@ -303,15 +304,12 @@ class DefaultCompiler(engine.Compiled): def visit_calculatedclause(self, clause, **kwargs): return self.process(clause.clause_expr) - def visit_cast(self, cast, **kwargs): - if self.stack and self.stack[-1].get('select'): - # not sure if we want to set the typemap here... - self.typemap.setdefault("CAST", cast.type) + def visit_cast(self, cast, typemap=None, **kwargs): return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause)) - def visit_function(self, func, **kwargs): - if self.stack and self.stack[-1].get('select'): - self.typemap.setdefault(func.name, func.type) + def visit_function(self, func, typemap=None, **kwargs): + if typemap is not None: + typemap.setdefault(func.name, func.type) if not self.apply_function_parens(func): return ".".join(func.packagenames + [func.name]) else: @@ -349,12 +347,14 @@ class DefaultCompiler(engine.Compiled): s = s + " " + self.operator_string(unary.modifier) return s - def visit_binary(self, binary, **kwargs): + def visit_binary(self, binary, typemap=None, **kwargs): op = self.operator_string(binary.operator) if callable(op): return op(self.process(binary.left), self.process(binary.right)) else: return self.process(binary.left) + " " + op + " " + self.process(binary.right) + + return ret def operator_string(self, operator): return self.operators.get(operator, str(operator)) @@ -453,6 +453,8 @@ class DefaultCompiler(engine.Compiled): column.table is not None and \ not isinstance(column.table, sql.Select): return column.label(column.name) + elif not isinstance(column, (sql._UnaryExpression, sql._TextClause)) and not hasattr(column, 'name'): + return column.label(None) else: return None @@ -462,13 +464,18 @@ class DefaultCompiler(engine.Compiled): if asfrom: stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True + column_clause_args = {} elif self.stack and self.stack[-1].get('select'): stack_entry['is_subquery'] = True - + column_clause_args = {} + else: + column_clause_args = {'typemap':self.typemap, 'column_labels':self.column_labels} + if self.stack and self.stack[-1].get('from'): existingfroms = self.stack[-1]['from'] else: existingfroms = None + froms = select._get_display_froms(existingfroms) correlate_froms = util.Set() @@ -492,15 +499,15 @@ class DefaultCompiler(engine.Compiled): labelname = co._label if labelname is not None: l = co.label(labelname) - inner_columns.add(self.process(l)) + inner_columns.add(self.process(l, **column_clause_args)) else: - inner_columns.add(self.process(co)) + inner_columns.add(self.process(co, **column_clause_args)) else: l = self.label_select_column(select, co) if l is not None: - inner_columns.add(self.process(l)) + inner_columns.add(self.process(l, **column_clause_args)) else: - inner_columns.add(self.process(co)) + inner_columns.add(self.process(co, **column_clause_args)) collist = string.join(inner_columns.difference(util.Set([None])), ', ') diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index e066632afb..7c42ae9a23 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1200,11 +1200,7 @@ class _CompareMixin(ColumnOperators): type_ = self._compare_type(obj) - # TODO: generalize operator overloading like this out into the - # types module - if op == operators.add and isinstance(type_, (sqltypes.Concatenable)): - op = operators.concat_op - return _BinaryExpression(self.expression_element(), obj, op, type_=type_) + return _BinaryExpression(self.expression_element(), obj, type_.adapt_operator(op), type_=type_) # a mapping of operators with the method they use, along with their negated # operator for comparison operators @@ -1289,7 +1285,10 @@ class _CompareMixin(ColumnOperators): return self.__compare(operators.like_op, po) def label(self, name): - """Produce a column label, i.e. `` AS ``""" + """Produce a column label, i.e. `` AS ``. + + if 'name' is None, an anonymous label name will be generated. + """ return _Label(name, self, self.type) def desc(self): @@ -1333,7 +1332,10 @@ class _CompareMixin(ColumnOperators): return _BindParamClause('literal', obj, type_=self.type, unique=True) def _check_literal(self, other): - if isinstance(other, Operators): + if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType): + other.type = self.type + return other + elif isinstance(other, Operators): return other.expression_element() elif _is_literal(other): return self._bind_param(other) diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index ec2d1072dc..9e1f6aa447 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -131,7 +131,14 @@ class AbstractType(object): """ return None - + + def adapt_operator(self, op): + """given an operator from the sqlalchemy.sql.operators package, + translate it to a new operator based on the semantics of this type. + + By default, returns the operator unchanged.""" + return op + def __repr__(self): return "%s(%s)" % (self.__class__.__name__, ",".join(["%s=%s" % (k, getattr(self, k)) for k in inspect.getargspec(self.__init__)[0][1:]])) @@ -282,9 +289,14 @@ NullTypeEngine = NullType class Concatenable(object): """marks a type as supporting 'concatenation'""" - pass + def adapt_operator(self, op): + from sqlalchemy.sql import operators + if op == operators.add: + return operators.concat_op + else: + return op -class String(TypeEngine, Concatenable): +class String(Concatenable, TypeEngine): def __init__(self, length=None, convert_unicode=False): self.length = length self.convert_unicode = convert_unicode diff --git a/test/sql/generative.py b/test/sql/generative.py index 1497ecde3d..040d4766b1 100644 --- a/test/sql/generative.py +++ b/test/sql/generative.py @@ -281,6 +281,7 @@ class ClauseTest(SQLCompileTest): self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2") self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 AS t2alias WHERE t1alias.col1 = t2alias.col2") self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = t2alias.col2") + def test_joins(self): """test that ClauseAdapter can target a Join object, replace it, and not dig into the sub-joins after diff --git a/test/sql/select.py b/test/sql/select.py index 699d05faa2..f9aa21f1e2 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -230,21 +230,21 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A def test_scalar_select(self): s = select([table1.c.myid], scalar=True, correlate=False) - self.assert_compile(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) FROM mytable") + self.assert_compile(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) AS anon_1 FROM mytable") s = select([table1.c.myid], scalar=True) - self.assert_compile(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) FROM myothertable") + self.assert_compile(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) AS anon_1 FROM myothertable") s = select([table1.c.myid]).correlate(None).as_scalar() - self.assert_compile(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) FROM mytable") + self.assert_compile(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) AS anon_1 FROM mytable") s = select([table1.c.myid]).as_scalar() - self.assert_compile(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) FROM myothertable") + self.assert_compile(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) AS anon_1 FROM myothertable") # test expressions against scalar selects - self.assert_compile(select([s - literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) - :literal") - self.assert_compile(select([select([table1.c.name]).as_scalar() + literal('x')]), "SELECT (SELECT mytable.name FROM mytable) || :literal") - self.assert_compile(select([s > literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) > :literal") + self.assert_compile(select([s - literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) - :literal AS anon_1") + self.assert_compile(select([select([table1.c.name]).as_scalar() + literal('x')]), "SELECT (SELECT mytable.name FROM mytable) || :literal AS anon_1") + self.assert_compile(select([s > literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) > :literal AS anon_1") self.assert_compile(select([select([table1.c.name]).label('foo')]), "SELECT (SELECT mytable.name FROM mytable) AS foo") @@ -294,7 +294,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A s1 = select([a1.c.otherid], table1.c.myid==a1.c.otherid, scalar=True) j1 = table1.join(table2, table1.c.myid==table2.c.otherid) s2 = select([table1, s1], from_obj=[j1]) - self.assert_compile(s2, "SELECT mytable.myid, mytable.name, mytable.description, (SELECT t2alias.otherid FROM myothertable AS t2alias WHERE mytable.myid = t2alias.otherid) FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid") + self.assert_compile(s2, "SELECT mytable.myid, mytable.name, mytable.description, (SELECT t2alias.otherid FROM myothertable AS t2alias WHERE mytable.myid = t2alias.otherid) AS anon_1 FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid") def testlabelcomparison(self): x = func.lala(table1.c.myid).label('foo') @@ -640,7 +640,7 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today def testliteral(self): self.assert_compile(select([literal("foo") + literal("bar")], from_obj=[table1]), - "SELECT :literal || :literal_1 FROM mytable") + "SELECT :literal || :literal_1 AS anon_1 FROM mytable") def testcalculatedcolumns(self): value_tbl = table('values', @@ -652,7 +652,7 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today self.assert_compile( select([value_tbl.c.id, (value_tbl.c.val2 - value_tbl.c.val1)/value_tbl.c.val1]), - "SELECT values.id, (values.val2 - values.val1) / values.val1 FROM values" + "SELECT values.id, (values.val2 - values.val1) / values.val1 AS anon_1 FROM values" ) self.assert_compile( @@ -1110,9 +1110,9 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE # coverage on other dialects. sel = select([tbl, cast(tbl.c.v1, Numeric)]).compile(dialect=dialect) if isinstance(dialect, type(mysql.dialect())): - self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS DECIMAL(10, 2)) \nFROM casttest") + self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS DECIMAL(10, 2)) AS anon_1 \nFROM casttest") else: - self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC(10, 2)) \nFROM casttest") + self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC(10, 2)) AS anon_1 \nFROM casttest") # first test with Postgres engine check_results(postgres.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(literal)s') diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index 4af96d57fe..630ecb9d53 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -3,6 +3,7 @@ import pickleable import datetime, os from sqlalchemy import * from sqlalchemy import types +from sqlalchemy.sql import operators import sqlalchemy.engine.url as url from sqlalchemy.databases import mssql, oracle, mysql, postgres, firebird from testlib import * @@ -367,7 +368,76 @@ class BinaryTest(AssertMixin): # put a number less than the typical MySQL default BLOB size return file(f).read(len) +class ExpressionTest(AssertMixin): + def setUpAll(self): + global test_table, meta + + class MyCustomType(types.TypeEngine): + def get_col_spec(self): + return "INT" + def bind_processor(self, dialect): + def process(value): + return value * 10 + return process + def result_processor(self, dialect): + def process(value): + return value / 10 + return process + def adapt_operator(self, op): + return {operators.add:operators.sub, operators.sub:operators.add}.get(op, op) + + meta = MetaData(testbase.db) + test_table = Table('test', meta, + Column('id', Integer, primary_key=True), + Column('data', String(30)), + Column('timestamp', Date), + Column('value', MyCustomType)) + + meta.create_all() + + test_table.insert().execute({'id':1, 'data':'somedata', 'timestamp':datetime.date(2007, 10, 15), 'value':25}) + + def tearDownAll(self): + meta.drop_all() + + def test_control(self): + assert testbase.db.execute("select value from test").scalar() == 250 + + assert test_table.select().execute().fetchall() == [(1, 'somedata', datetime.date(2007, 10, 15), 25)] + + def test_bind_adapt(self): + expr = test_table.c.timestamp == bindparam("thedate") + assert expr.right.type.__class__ == test_table.c.timestamp.type.__class__ + + assert testbase.db.execute(test_table.select().where(expr), {"thedate":datetime.date(2007, 10, 15)}).fetchall() == [(1, 'somedata', datetime.date(2007, 10, 15), 25)] + + expr = test_table.c.value == bindparam("somevalue") + assert expr.right.type.__class__ == test_table.c.value.type.__class__ + assert testbase.db.execute(test_table.select().where(expr), {"somevalue":25}).fetchall() == [(1, 'somedata', datetime.date(2007, 10, 15), 25)] + + + def test_operator_adapt(self): + """test type-based overloading of operators""" + + # test string concatenation + expr = test_table.c.data + "somedata" + assert testbase.db.execute(select([expr])).scalar() == "somedatasomedata" + expr = test_table.c.id + 15 + assert testbase.db.execute(select([expr])).scalar() == 16 + + # test custom operator conversion + expr = test_table.c.value + 40 + assert expr.type.__class__ is test_table.c.value.type.__class__ + + # + operator converted to - + # value is calculated as: (250 - (40 * 10)) / 10 == -15 + assert testbase.db.execute(select([expr.label('foo')])).scalar() == -15 + + # this one relies upon anonymous labeling to assemble result + # processing rules on the column. + assert testbase.db.execute(select([expr])).scalar() == -15 + class DateTest(AssertMixin): def setUpAll(self): global users_with_date, insert_data