]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- anonymous column expressions are automatically labeled.
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 10 Nov 2007 03:02:16 +0000 (03:02 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 10 Nov 2007 03:02:16 +0000 (03:02 +0000)
    e.g. select([x* 5]) produces "SELECT x * 5 AS anon_1".
    This allows the labelname to be present in the cursor.description
    which can then be appropriately matched to result-column processing
    rules. (we can't reliably use positional tracking for result-column
    matches since text() expressions may represent multiple columns).

  - operator overloading is now controlled by TypeEngine objects - the
    one built-in operator overload so far is String types overloading
    '+' to be the string concatenation operator.
    User-defined types can also define their own operator overloading
    by overriding the adapt_operator(self, op) method.

  - untyped bind parameters on the right side of a binary expression
    will be assigned the type of the left side of the operation, to better
    enable the appropriate bind parameter processing to take effect
    [ticket:819]

CHANGES
lib/sqlalchemy/databases/informix.py
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/types.py
test/sql/generative.py
test/sql/select.py
test/sql/testtypes.py

diff --git a/CHANGES b/CHANGES
index 60164c23d8278c7029fb95c969c434e29cb5b946..282079b8b286d5977aa5bfaaf1a38002464807ce 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -12,6 +12,24 @@ CHANGES
   
   - Added contains operator (generates a "LIKE %<other>%" clause).
 
+  - anonymous column expressions are automatically labeled.  
+    e.g. select([x* 5]) produces "SELECT x * 5 AS anon_1".
+    This allows the labelname to be present in the cursor.description
+    which can then be appropriately matched to result-column processing
+    rules. (we can't reliably use positional tracking for result-column 
+    matches since text() expressions may represent multiple columns).
+  
+  - operator overloading is now controlled by TypeEngine objects - the 
+    one built-in operator overload so far is String types overloading
+    '+' to be the string concatenation operator.
+    User-defined types can also define their own operator overloading
+    by overriding the adapt_operator(self, op) method.
+    
+  - untyped bind parameters on the right side of a binary expression
+    will be assigned the type of the left side of the operation, to better
+    enable the appropriate bind parameter processing to take effect
+    [ticket:819]
+    
   - Removed regular expression step from most statement compilations.
     Also fixes [ticket:833]
 
index 741641afcff7bae2fdf38a7e5fbfebe6a6eee4cc..247ab2d41902bc143b68cfe632a4531a7c7c69dd 100644 (file)
@@ -410,6 +410,7 @@ class InfoCompiler(compiler.DefaultCompiler):
         return ""
 
     def __visit_label(self, label):
+        # TODO: whats this method for ?
         if self.select_stack:
             self.typemap.setdefault(label.name.lower(), label.obj.type)
         if self.strings[label.obj]:
index 38e8e1217d541af26078a54f97373cf4e75003a2..92b454f820121fe90326171e88489874536d11e8 100644 (file)
@@ -895,13 +895,13 @@ class MSSQLCompiler(compiler.DefaultCompiler):
         kwargs['mssql_aliased'] = True
         return super(MSSQLCompiler, self).visit_alias(alias, **kwargs)
 
-    def visit_column(self, column):
+    def visit_column(self, column, **kwargs):
         if column.table is not None and not self.isupdate and not self.isdelete:
             # translate for schema-qualified table aliases
             t = self._schema_aliased_table(column.table)
             if t is not None:
                 return self.process(t.corresponding_column(column))
-        return super(MSSQLCompiler, self).visit_column(column)
+        return super(MSSQLCompiler, self).visit_column(column, **kwargs)
 
     def visit_binary(self, binary):
         """Move bind parameters to the right-hand side of an operator, where possible."""
index 00b297f973b952bf0c4a4ceacf2e819b6ef93fbe..88ac0e2026033e57824fb25c1d2862f5631a63f1 100644 (file)
@@ -109,7 +109,7 @@ class PGBoolean(sqltypes.Boolean):
     def get_col_spec(self):
         return "BOOLEAN"
 
-class PGArray(sqltypes.TypeEngine, sqltypes.Concatenable):
+class PGArray(sqltypes.Concatenable, sqltypes.TypeEngine):
     def __init__(self, item_type):
         if isinstance(item_type, type):
             item_type = item_type()
index 013c5704b9e9502d1004124bec4fc3bcd9afa3f3..859fb796e82e72d44967cbea4d676beb304bc48c 100644 (file)
@@ -1339,6 +1339,7 @@ class ResultProxy(object):
             for i, item in enumerate(metadata):
                 # sqlite possibly prepending table name to colnames so strip
                 colname = (item[0].split('.')[-1]).decode(self.dialect.encoding)
+
                 if self.context.typemap is not None:
                     type = self.context.typemap.get(colname.lower(), typemap.get(item[1], types.NULLTYPE))
                 else:
index 9c82cd4aa6963fe67e35a5944a4e039b85569cbb..14011706939d2532d07b901f619b72e50d35f7b7 100644 (file)
@@ -213,17 +213,19 @@ class DefaultCompiler(engine.Compiled):
     def visit_grouping(self, grouping, **kwargs):
         return "(" + self.process(grouping.elem) + ")"
         
-    def visit_label(self, label):
+    def visit_label(self, label, typemap=None, column_labels=None):
         labelname = self._truncated_identifier("colident", label.name)
         
-        if len(self.stack) == 1 and self.stack[-1].get('select'):
+        if typemap is not None:
             self.typemap.setdefault(labelname.lower(), label.obj.type)
+            
+        if column_labels is not None:
             if isinstance(label.obj, sql._ColumnClause):
-                self.column_labels[label.obj._label] = labelname
-            self.column_labels[label.name] = labelname
+                column_labels[label.obj._label] = labelname
+            column_labels[label.name] = labelname
         return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)])
         
-    def visit_column(self, column, **kwargs):
+    def visit_column(self, column, typemap=None, column_labels=None, **kwargs):
         # there is actually somewhat of a ruleset when you would *not* necessarily
         # want to truncate a column identifier, if its mapped to the name of a 
         # physical column.  but thats very hard to identify at this point, and 
@@ -234,10 +236,9 @@ class DefaultCompiler(engine.Compiled):
         else:
             name = column.name
 
-        if len(self.stack) == 1 and self.stack[-1].get('select'):
-            # if we are within a visit to a Select, set up the "typemap"
-            # for this column which is used to translate result set values
-            self.typemap.setdefault(name.lower(), column.type)
+        if typemap is not None:
+            typemap.setdefault(name.lower(), column.type)
+        if column_labels is not None:    
             self.column_labels.setdefault(column._label, name.lower())
         
         if column._is_oid:
@@ -303,15 +304,12 @@ class DefaultCompiler(engine.Compiled):
     def visit_calculatedclause(self, clause, **kwargs):
         return self.process(clause.clause_expr)
 
-    def visit_cast(self, cast, **kwargs):
-        if self.stack and self.stack[-1].get('select'):
-            # not sure if we want to set the typemap here...
-            self.typemap.setdefault("CAST", cast.type)
+    def visit_cast(self, cast, typemap=None, **kwargs):
         return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))
 
-    def visit_function(self, func, **kwargs):
-        if self.stack and self.stack[-1].get('select'):
-            self.typemap.setdefault(func.name, func.type)
+    def visit_function(self, func, typemap=None, **kwargs):
+        if typemap is not None:
+            typemap.setdefault(func.name, func.type)
         if not self.apply_function_parens(func):
             return ".".join(func.packagenames + [func.name])
         else:
@@ -349,12 +347,14 @@ class DefaultCompiler(engine.Compiled):
             s = s + " " + self.operator_string(unary.modifier)
         return s
         
-    def visit_binary(self, binary, **kwargs):
+    def visit_binary(self, binary, typemap=None, **kwargs):
         op = self.operator_string(binary.operator)
         if callable(op):
             return op(self.process(binary.left), self.process(binary.right))
         else:
             return self.process(binary.left) + " " + op + " " + self.process(binary.right)
+            
+        return ret
         
     def operator_string(self, operator):
         return self.operators.get(operator, str(operator))
@@ -453,6 +453,8 @@ class DefaultCompiler(engine.Compiled):
             column.table is not None and \
             not isinstance(column.table, sql.Select):
             return column.label(column.name)
+        elif not isinstance(column, (sql._UnaryExpression, sql._TextClause)) and not hasattr(column, 'name'):
+            return column.label(None)
         else:
             return None
 
@@ -462,13 +464,18 @@ class DefaultCompiler(engine.Compiled):
         
         if asfrom:
             stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True
+            column_clause_args = {}
         elif self.stack and self.stack[-1].get('select'):
             stack_entry['is_subquery'] = True
-
+            column_clause_args = {}
+        else:
+            column_clause_args = {'typemap':self.typemap, 'column_labels':self.column_labels}
+            
         if self.stack and self.stack[-1].get('from'):
             existingfroms = self.stack[-1]['from']
         else:
             existingfroms = None
+            
         froms = select._get_display_froms(existingfroms)
 
         correlate_froms = util.Set()
@@ -492,15 +499,15 @@ class DefaultCompiler(engine.Compiled):
                 labelname = co._label
                 if labelname is not None:
                     l = co.label(labelname)
-                    inner_columns.add(self.process(l))
+                    inner_columns.add(self.process(l, **column_clause_args))
                 else:
-                    inner_columns.add(self.process(co))
+                    inner_columns.add(self.process(co, **column_clause_args))
             else:
                 l = self.label_select_column(select, co)
                 if l is not None:
-                    inner_columns.add(self.process(l))
+                    inner_columns.add(self.process(l, **column_clause_args))
                 else:
-                    inner_columns.add(self.process(co))
+                    inner_columns.add(self.process(co, **column_clause_args))
             
         collist = string.join(inner_columns.difference(util.Set([None])), ', ')
 
index e066632afbe5505b59bb7660777368270af76847..7c42ae9a23eb83ed924dcbb371c0564465400182 100644 (file)
@@ -1200,11 +1200,7 @@ class _CompareMixin(ColumnOperators):
 
         type_ = self._compare_type(obj)
 
-        # TODO: generalize operator overloading like this out into the
-        # types module
-        if op == operators.add and isinstance(type_, (sqltypes.Concatenable)):
-            op = operators.concat_op
-        return _BinaryExpression(self.expression_element(), obj, op, type_=type_)
+        return _BinaryExpression(self.expression_element(), obj, type_.adapt_operator(op), type_=type_)
 
     # a mapping of operators with the method they use, along with their negated
     # operator for comparison operators
@@ -1289,7 +1285,10 @@ class _CompareMixin(ColumnOperators):
         return self.__compare(operators.like_op, po)
 
     def label(self, name):
-        """Produce a column label, i.e. ``<columnname> AS <name>``"""
+        """Produce a column label, i.e. ``<columnname> AS <name>``.
+        
+        if 'name' is None, an anonymous label name will be generated.
+        """
         return _Label(name, self, self.type)
 
     def desc(self):
@@ -1333,7 +1332,10 @@ class _CompareMixin(ColumnOperators):
         return _BindParamClause('literal', obj, type_=self.type, unique=True)
 
     def _check_literal(self, other):
-        if isinstance(other, Operators):
+        if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType):
+            other.type = self.type
+            return other
+        elif isinstance(other, Operators):
             return other.expression_element()
         elif _is_literal(other):
             return self._bind_param(other)
index ec2d1072dc323262ad06e2936ea3328202dec1e8..9e1f6aa447bc8c8894bde58e29d4886febc1a7c7 100644 (file)
@@ -131,7 +131,14 @@ class AbstractType(object):
         """
 
         return None
-
+    
+    def adapt_operator(self, op):
+        """given an operator from the sqlalchemy.sql.operators package, 
+        translate it to a new operator based on the semantics of this type.
+        
+        By default, returns the operator unchanged."""
+        return op
+        
     def __repr__(self):
         return "%s(%s)" % (self.__class__.__name__, ",".join(["%s=%s" % (k, getattr(self, k)) for k in inspect.getargspec(self.__init__)[0][1:]]))
 
@@ -282,9 +289,14 @@ NullTypeEngine = NullType
 
 class Concatenable(object):
     """marks a type as supporting 'concatenation'"""
-    pass
+    def adapt_operator(self, op):
+        from sqlalchemy.sql import operators
+        if op == operators.add:
+            return operators.concat_op
+        else:
+            return op
     
-class String(TypeEngine, Concatenable):
+class String(Concatenable, TypeEngine):
     def __init__(self, length=None, convert_unicode=False):
         self.length = length
         self.convert_unicode = convert_unicode
index 1497ecde3d93971ec3c8f682bb9bb9d538c13d25..040d4766b1e36310762f118c26e2a482a08249fe 100644 (file)
@@ -281,6 +281,7 @@ class ClauseTest(SQLCompileTest):
         self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
         self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
         self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = t2alias.col2")
+    
 
     def test_joins(self):
         """test that ClauseAdapter can target a Join object, replace it, and not dig into the sub-joins after
index 699d05faa2e0ddb755ff83b56e743852e2ee980b..f9aa21f1e29583355ca2f9c96bbca561ad6f6fff 100644 (file)
@@ -230,21 +230,21 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
 
     def test_scalar_select(self):
         s = select([table1.c.myid], scalar=True, correlate=False)
-        self.assert_compile(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) FROM mytable")
+        self.assert_compile(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) AS anon_1 FROM mytable")
 
         s = select([table1.c.myid], scalar=True)
-        self.assert_compile(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) FROM myothertable")
+        self.assert_compile(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) AS anon_1 FROM myothertable")
 
         s = select([table1.c.myid]).correlate(None).as_scalar()
-        self.assert_compile(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) FROM mytable")
+        self.assert_compile(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) AS anon_1 FROM mytable")
 
         s = select([table1.c.myid]).as_scalar()
-        self.assert_compile(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) FROM myothertable")
+        self.assert_compile(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) AS anon_1 FROM myothertable")
 
         # test expressions against scalar selects
-        self.assert_compile(select([s - literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) - :literal")
-        self.assert_compile(select([select([table1.c.name]).as_scalar() + literal('x')]), "SELECT (SELECT mytable.name FROM mytable) || :literal")
-        self.assert_compile(select([s > literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) > :literal")
+        self.assert_compile(select([s - literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) - :literal AS anon_1")
+        self.assert_compile(select([select([table1.c.name]).as_scalar() + literal('x')]), "SELECT (SELECT mytable.name FROM mytable) || :literal AS anon_1")
+        self.assert_compile(select([s > literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) > :literal AS anon_1")
 
         self.assert_compile(select([select([table1.c.name]).label('foo')]), "SELECT (SELECT mytable.name FROM mytable) AS foo")
 
@@ -294,7 +294,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
         s1 = select([a1.c.otherid], table1.c.myid==a1.c.otherid, scalar=True)
         j1 = table1.join(table2, table1.c.myid==table2.c.otherid)
         s2 = select([table1, s1], from_obj=[j1])
-        self.assert_compile(s2, "SELECT mytable.myid, mytable.name, mytable.description, (SELECT t2alias.otherid FROM myothertable AS t2alias WHERE mytable.myid = t2alias.otherid) FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid")
+        self.assert_compile(s2, "SELECT mytable.myid, mytable.name, mytable.description, (SELECT t2alias.otherid FROM myothertable AS t2alias WHERE mytable.myid = t2alias.otherid) AS anon_1 FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid")
 
     def testlabelcomparison(self):
         x = func.lala(table1.c.myid).label('foo')
@@ -640,7 +640,7 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today
 
     def testliteral(self):
         self.assert_compile(select([literal("foo") + literal("bar")], from_obj=[table1]),
-            "SELECT :literal || :literal_1 FROM mytable")
+            "SELECT :literal || :literal_1 AS anon_1 FROM mytable")
 
     def testcalculatedcolumns(self):
          value_tbl = table('values',
@@ -652,7 +652,7 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today
          self.assert_compile(
              select([value_tbl.c.id, (value_tbl.c.val2 -
      value_tbl.c.val1)/value_tbl.c.val1]),
-             "SELECT values.id, (values.val2 - values.val1) / values.val1 FROM values"
+             "SELECT values.id, (values.val2 - values.val1) / values.val1 AS anon_1 FROM values"
          )
 
          self.assert_compile(
@@ -1110,9 +1110,9 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE
             # coverage on other dialects.
             sel = select([tbl, cast(tbl.c.v1, Numeric)]).compile(dialect=dialect)
             if isinstance(dialect, type(mysql.dialect())):
-                self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS DECIMAL(10, 2)) \nFROM casttest")
+                self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS DECIMAL(10, 2)) AS anon_1 \nFROM casttest")
             else:
-                self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC(10, 2)) \nFROM casttest")
+                self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC(10, 2)) AS anon_1 \nFROM casttest")
 
         # first test with Postgres engine
         check_results(postgres.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(literal)s')
index 4af96d57fe96e5f520866051697cc15768a425d7..630ecb9d53a711952208db21a849b4f9bb9bec99 100644 (file)
@@ -3,6 +3,7 @@ import pickleable
 import datetime, os
 from sqlalchemy import *
 from sqlalchemy import types
+from sqlalchemy.sql import operators
 import sqlalchemy.engine.url as url
 from sqlalchemy.databases import mssql, oracle, mysql, postgres, firebird
 from testlib import *
@@ -367,7 +368,76 @@ class BinaryTest(AssertMixin):
         # put a number less than the typical MySQL default BLOB size
         return file(f).read(len)
 
+class ExpressionTest(AssertMixin):
+    def setUpAll(self):
+        global test_table, meta
+
+        class MyCustomType(types.TypeEngine):
+            def get_col_spec(self):
+                return "INT"
+            def bind_processor(self, dialect):
+                def process(value):
+                    return value * 10
+                return process
+            def result_processor(self, dialect):
+                def process(value):
+                    return value / 10
+                return process
+            def adapt_operator(self, op):
+                return {operators.add:operators.sub, operators.sub:operators.add}.get(op, op)
+                
+        meta = MetaData(testbase.db)
+        test_table = Table('test', meta, 
+            Column('id', Integer, primary_key=True),
+            Column('data', String(30)),
+            Column('timestamp', Date),
+            Column('value', MyCustomType))
+        
+        meta.create_all()
+        
+        test_table.insert().execute({'id':1, 'data':'somedata', 'timestamp':datetime.date(2007, 10, 15), 'value':25})
+        
+    def tearDownAll(self):
+        meta.drop_all()
+    
+    def test_control(self):
+        assert testbase.db.execute("select value from test").scalar() == 250
+        
+        assert test_table.select().execute().fetchall() == [(1, 'somedata', datetime.date(2007, 10, 15), 25)]
+        
+    def test_bind_adapt(self):
+        expr = test_table.c.timestamp == bindparam("thedate")
+        assert expr.right.type.__class__ == test_table.c.timestamp.type.__class__
+        
+        assert testbase.db.execute(test_table.select().where(expr), {"thedate":datetime.date(2007, 10, 15)}).fetchall() == [(1, 'somedata', datetime.date(2007, 10, 15), 25)]
+
+        expr = test_table.c.value == bindparam("somevalue")
+        assert expr.right.type.__class__ == test_table.c.value.type.__class__
+        assert testbase.db.execute(test_table.select().where(expr), {"somevalue":25}).fetchall() == [(1, 'somedata', datetime.date(2007, 10, 15), 25)]
+        
+
+    def test_operator_adapt(self):
+        """test type-based overloading of operators"""
+        
+        # test string concatenation
+        expr = test_table.c.data + "somedata"
+        assert testbase.db.execute(select([expr])).scalar() == "somedatasomedata"
 
+        expr = test_table.c.id + 15
+        assert testbase.db.execute(select([expr])).scalar() == 16
+
+        # test custom operator conversion
+        expr = test_table.c.value + 40
+        assert expr.type.__class__ is test_table.c.value.type.__class__
+        
+        # + operator converted to -
+        # value is calculated as: (250 - (40 * 10)) / 10 == -15
+        assert testbase.db.execute(select([expr.label('foo')])).scalar() == -15
+
+        # this one relies upon anonymous labeling to assemble result
+        # processing rules on the column.
+        assert testbase.db.execute(select([expr])).scalar() == -15
+        
 class DateTest(AssertMixin):
     def setUpAll(self):
         global users_with_date, insert_data