]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
working on getting operators/left hand type awareness into the "bind" coercion. ...
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 8 Mar 2010 23:27:35 +0000 (18:27 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 8 Mar 2010 23:27:35 +0000 (18:27 -0500)
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/types.py
test/sql/test_types.py

index 1c3961f1f78f462d00857e47927e65b6a6d8eba1..49ec34ab22bfa88ea37f98698b2985ebfe2d2e53 100644 (file)
@@ -1443,7 +1443,7 @@ class _CompareMixin(ColumnOperators):
             else:
                 raise exc.ArgumentError("Only '='/'!=' operators can be used with NULL")
         else:
-            obj = self._check_literal(obj)
+            obj = self._check_literal(op, obj)
 
         if reverse:
             return _BinaryExpression(obj, 
@@ -1459,7 +1459,7 @@ class _CompareMixin(ColumnOperators):
                             negate=negate, modifiers=kwargs)
 
     def __operate(self, op, obj, reverse=False):
-        obj = self._check_literal(obj)
+        obj = self._check_literal(op, obj)
         
         if reverse:
             left, right = obj, self
@@ -1532,7 +1532,7 @@ class _CompareMixin(ColumnOperators):
                         "in() function accepts either a list of non-selectable values, "
                         "or a selectable: %r" % o)
             else:
-                o = self._bind_param(o)
+                o = self._bind_param(op, o)
             args.append(o)
 
         if len(args) == 0:
@@ -1558,7 +1558,7 @@ class _CompareMixin(ColumnOperators):
         # use __radd__ to force string concat behavior
         return self.__compare(
             operators.like_op,
-            literal_column("'%'", type_=sqltypes.String).__radd__(self._check_literal(other)),
+            literal_column("'%'", type_=sqltypes.String).__radd__(self._check_literal(operators.like_op, other)),
             escape=escape)
 
     def endswith(self, other, escape=None):
@@ -1566,7 +1566,7 @@ class _CompareMixin(ColumnOperators):
 
         return self.__compare(
             operators.like_op,
-            literal_column("'%'", type_=sqltypes.String) + self._check_literal(other),
+            literal_column("'%'", type_=sqltypes.String) + self._check_literal(operators.like_op, other),
             escape=escape)
 
     def contains(self, other, escape=None):
@@ -1575,7 +1575,7 @@ class _CompareMixin(ColumnOperators):
         return self.__compare(
             operators.like_op,
             literal_column("'%'", type_=sqltypes.String) +
-                self._check_literal(other) +
+                self._check_literal(operators.like_op, other) +
                 literal_column("'%'", type_=sqltypes.String),
             escape=escape)
 
@@ -1585,7 +1585,7 @@ class _CompareMixin(ColumnOperators):
         The allowed contents of ``other`` are database backend specific.
 
         """
-        return self.__compare(operators.match_op, self._check_literal(other))
+        return self.__compare(operators.match_op, self._check_literal(operators.match_op, other))
 
     def label(self, name):
         """Produce a column label, i.e. ``<columnname> AS <name>``.
@@ -1615,8 +1615,8 @@ class _CompareMixin(ColumnOperators):
         return _BinaryExpression(
                 self,
                 ClauseList(
-                    self._check_literal(cleft),
-                    self._check_literal(cright),
+                    self._check_literal(operators.and_, cleft),
+                    self._check_literal(operators.and_, cright),
                     operator=operators.and_,
                     group=False),
                 operators.between_op)
@@ -1651,17 +1651,18 @@ class _CompareMixin(ColumnOperators):
         """
         return lambda other: self.__operate(operator, other)
 
-    def _bind_param(self, obj):
-        return _BindParamClause(None, obj, _fallback_type=self.type, unique=True)
+    def _bind_param(self, operator, obj):
+        return _BindParamClause(None, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True)
 
-    def _check_literal(self, other):
-        if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType):
+    def _check_literal(self, operator, other):
+        if isinstance(other, _BindParamClause) and \
+            isinstance(other.type, sqltypes.NullType):
             other.type = self.type
             return other
         elif hasattr(other, '__clause_element__'):
             return other.__clause_element__()
         elif not isinstance(other, ClauseElement):
-            return self._bind_param(other)
+            return self._bind_param(operator, other)
         elif isinstance(other, (_SelectBaseMixin, Alias)):
             return other.as_scalar()
         else:
@@ -2108,7 +2109,8 @@ class _BindParamClause(ColumnElement):
 
     def __init__(self, key, value, type_=None, unique=False, 
                             isoutparam=False, required=False, 
-                            _fallback_type=None):
+                            _compared_to_operator=None,
+                            _compared_to_type=None):
         """Construct a _BindParamClause.
 
         key
@@ -2154,9 +2156,10 @@ class _BindParamClause(ColumnElement):
         self.required = required
         
         if type_ is None:
-            self.type = sqltypes.type_map.get(type(value), _fallback_type or sqltypes.NULLTYPE)
-            if _fallback_type and _fallback_type._type_affinity == self.type._type_affinity:
-                self.type = _fallback_type
+            if _compared_to_type is not None:
+                self.type = _compared_to_type._coerce_compared_value(_compared_to_operator, value)
+            else:
+                self.type = sqltypes.NULLTYPE
         elif isinstance(type_, type):
             self.type = type_()
         else:
@@ -2434,9 +2437,9 @@ class _Tuple(ClauseList, ColumnElement):
     def _select_iterable(self):
         return (self, )
 
-    def _bind_param(self, obj):
+    def _bind_param(self, operator, obj):
         return _Tuple(*[
-            _BindParamClause(None, o, _fallback_type=self.type, unique=True)
+            _BindParamClause(None, o, _compared_to_operator=operator, _compared_to_type=self.type, unique=True)
             for o in obj
         ]).self_group()
     
@@ -2538,8 +2541,8 @@ class FunctionElement(Executable, ColumnElement, FromClause):
     def execute(self):
         return self.select().execute()
 
-    def _bind_param(self, obj):
-        return _BindParamClause(None, obj, _fallback_type=self.type, unique=True)
+    def _bind_param(self, operator, obj):
+        return _BindParamClause(None, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True)
 
     
 class Function(FunctionElement):
@@ -2555,8 +2558,8 @@ class Function(FunctionElement):
         
         FunctionElement.__init__(self, *clauses, **kw)
 
-    def _bind_param(self, obj):
-        return _BindParamClause(self.name, obj, _fallback_type=self.type, unique=True)
+    def _bind_param(self, operator, obj):
+        return _BindParamClause(self.name, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True)
 
 
 class _Cast(ColumnElement):
@@ -3165,8 +3168,8 @@ class ColumnClause(_Immutable, ColumnElement):
         else:
             return []
 
-    def _bind_param(self, obj):
-        return _BindParamClause(self.name, obj, _fallback_type=self.type, unique=True)
+    def _bind_param(self, operator, obj):
+        return _BindParamClause(self.name, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True)
 
     def _make_proxy(self, selectable, name=None, attach=True):
         # propagate the "is_literal" flag only if we are keeping our name,
index cdbf7927efb7ae77b5562c8595d73246365b9f92..4d6a28aadcdc1aebf268c5af5f8d8a4081d1064d 100644 (file)
@@ -116,6 +116,13 @@ class AbstractType(Visitable):
                 typ = t
         else:
             return self.__class__
+    
+    def _coerce_compared_value(self, op, value):
+        _coerced_type = type_map.get(type(value), NULLTYPE)
+        if _coerced_type._type_affinity == self._type_affinity:
+            return self
+        else:
+            return _coerced_type
         
     def _compare_type_affinity(self, other):
         return self._type_affinity is other._type_affinity
@@ -239,7 +246,7 @@ class TypeDecorator(AbstractType):
           # strips it off on the way out.
 
           impl = types.Unicode
-
+          
           def process_bind_param(self, value, dialect):
               return "PREFIX:" + value
 
@@ -255,6 +262,44 @@ class TypeDecorator(AbstractType):
     given; in this case, the "impl" variable can reference
     ``TypeEngine`` as a placeholder.
 
+    Types that receive a Python type that isn't similar to the 
+    ultimate type used may want to define the :meth:`TypeDecorator.coerce_compared_value`
+    method=.  This is used to give the expression system a hint 
+    when coercing Python objects
+    into bind parameters within expressions.  Consider this expression::
+    
+        mytable.c.somecol + datetime.date(2009, 5, 15)
+        
+    Above, if "somecol" is an ``Integer`` variant, it makes sense that 
+    we doing date arithmetic, where above is usually interpreted
+    by databases as adding a number of days to the given date. 
+    The expression system does the right thing by not attempting to
+    coerce the "date()" value into an integer-oriented bind parameter.
+    
+    However, suppose "somecol" is a ``TypeDecorator`` that is wrapping
+    an ``Integer``, and our ``TypeDecorator`` is actually storing dates
+    as an "epoch", i.e. a total number of days from a fixed starting
+    date.  So in this case, we *do* want the expression system to wrap 
+    the date() into our ``TypeDecorator`` type's system of coercing 
+    dates into integers.   So we would want to define::
+    
+        class MyEpochType(types.TypeDecorator):
+            impl = types.Integer
+            
+            epoch = datetime.date(1970, 1, 1)
+            
+            def process_bind_param(self, value, dialect):
+                return (value - self.epoch).days
+            
+            def process_result_value(self, value, dialect):
+                return self.epoch + timedelta(days=value)
+                
+            def coerce_compared_value(self, op, value):
+                if isinstance(value, datetime.date):
+                    return Date
+                else:
+                    raise ValueError("Python date expected.")
+
     The reason that type behavior is modified using class decoration
     instead of subclassing is due to the way dialect specific types
     are used.  Such as with the example above, when using the mysql
@@ -365,7 +410,13 @@ class TypeDecorator(AbstractType):
             return process
         else:
             return self.impl.result_processor(dialect, coltype)
+    
+    def coerce_compared_value(self, op, value):
+        return self.impl._coerce_compared_value(op, value)
 
+    def _coerce_compared_value(self, op, value):
+        return self.coerce_compared_value(op, value)
+        
     def copy(self):
         instance = self.__class__.__new__(self.__class__)
         instance.__dict__.update(self.__dict__)
@@ -384,6 +435,11 @@ class TypeDecorator(AbstractType):
     def is_mutable(self):
         return self.impl.is_mutable()
 
+    def _adapt_expression(self, op, othertype):
+        return self.impl._adapt_expression(op, othertype)
+
+
+
 class MutableType(object):
     """A mixin that marks a Type as holding a mutable object.
 
@@ -461,7 +517,7 @@ class Concatenable(object):
     """A mixin that marks a type as supporting 'concatenation', typically strings."""
 
     def _adapt_expression(self, op, othertype):
-        if op is operators.add and isinstance(othertype, (Concatenable, NullType)):
+        if op is operators.add and issubclass(othertype._type_affinity, (Concatenable, NullType)):
             return operators.concat_op, self
         else:
             return op, self
index 3ac8baf00464f2c175b74724efbcb5b25327b294..ad58f1867ebaba5d1b80b7c5d420514120511af6 100644 (file)
@@ -4,7 +4,7 @@ import decimal
 import datetime, os, re
 from sqlalchemy import *
 from sqlalchemy import exc, types, util, schema
-from sqlalchemy.sql import operators, column
+from sqlalchemy.sql import operators, column, table
 from sqlalchemy.test.testing import eq_
 import sqlalchemy.engine.url as url
 from sqlalchemy.databases import *
@@ -687,10 +687,10 @@ class BinaryTest(TestBase, AssertsExecutionResults):
         f = os.path.join(os.path.dirname(__file__), "..", name)
         return open(f, mode='rb').read()
 
-class ExpressionTest(TestBase, AssertsExecutionResults):
+class ExpressionTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
     @classmethod
     def setup_class(cls):
-        global test_table, meta, MyCustomType
+        global test_table, meta, MyCustomType, MyTypeDec
 
         class MyCustomType(types.UserDefinedType):
             def get_col_spec(self):
@@ -705,13 +705,24 @@ class ExpressionTest(TestBase, AssertsExecutionResults):
                 return process
             def adapt_operator(self, op):
                 return {operators.add:operators.sub, operators.sub:operators.add}.get(op, op)
+        
+        class MyTypeDec(types.TypeDecorator):
+            impl = String
+            
+            def process_bind_param(self, value, dialect):
+                return "BIND_IN" + str(value)
 
+            def process_result_value(self, value, dialect):
+                return value + "BIND_OUT"
+            
         meta = MetaData(testing.db)
         test_table = Table('test', meta,
             Column('id', Integer, primary_key=True),
             Column('data', String(30)),
             Column('atimestamp', Date),
-            Column('avalue', MyCustomType))
+            Column('avalue', MyCustomType),
+            Column('bvalue', MyTypeDec),
+            )
 
         meta.create_all()
 
@@ -719,7 +730,7 @@ class ExpressionTest(TestBase, AssertsExecutionResults):
                                         'id':1, 
                                         'data':'somedata', 
                                         'atimestamp':datetime.date(2007, 10, 15), 
-                                        'avalue':25})
+                                        'avalue':25, 'bvalue':'foo'})
 
     @classmethod
     def teardown_class(cls):
@@ -730,7 +741,7 @@ class ExpressionTest(TestBase, AssertsExecutionResults):
 
         eq_(
             test_table.select().execute().fetchall(),
-            [(1, 'somedata', datetime.date(2007, 10, 15), 25)]
+            [(1, 'somedata', datetime.date(2007, 10, 15), 25, "BIND_INfooBIND_OUT")]
         )
 
     def test_bind_adapt(self):
@@ -740,17 +751,26 @@ class ExpressionTest(TestBase, AssertsExecutionResults):
 
         eq_(
             testing.db.execute(
-                            test_table.select().where(expr), 
+                            select([test_table.c.id, test_table.c.data, test_table.c.atimestamp])
+                            .where(expr), 
                             {"thedate":datetime.date(2007, 10, 15)}).fetchall(),
-            [(1, 'somedata', datetime.date(2007, 10, 15), 25)]
+            [(1, 'somedata', datetime.date(2007, 10, 15))]
         )
 
         expr = test_table.c.avalue == bindparam("somevalue")
         eq_(expr.right.type._type_affinity, MyCustomType)
-        
+
         eq_(
             testing.db.execute(test_table.select().where(expr), {"somevalue":25}).fetchall(),
-            [(1, 'somedata', datetime.date(2007, 10, 15), 25)]
+            [(1, 'somedata', datetime.date(2007, 10, 15), 25, 'BIND_INfooBIND_OUT')]
+        )
+
+        expr = test_table.c.bvalue == bindparam("somevalue")
+        eq_(expr.right.type._type_affinity, String)
+        
+        eq_(
+            testing.db.execute(test_table.select().where(expr), {"somevalue":"foo"}).fetchall(),
+            [(1, 'somedata', datetime.date(2007, 10, 15), 25, 'BIND_INfooBIND_OUT')]
         )
     
     def test_literal_adapt(self):
@@ -799,7 +819,43 @@ class ExpressionTest(TestBase, AssertsExecutionResults):
         # this one relies upon anonymous labeling to assemble result
         # processing rules on the column.
         assert testing.db.execute(select([expr])).scalar() == -15
-    
+
+    def test_typedec_operator_adapt(self):
+        expr = test_table.c.bvalue + "hi"
+        
+        assert expr.type.__class__ is String
+
+        eq_(
+            testing.db.execute(select([expr.label('foo')])).scalar(),
+            "BIND_INfooBIND_INhiBIND_OUT"
+        )
+
+    def test_typedec_righthand_coercion(self):
+        class MyTypeDec(types.TypeDecorator):
+            impl = String
+            
+            def process_bind_param(self, value, dialect):
+                return "BIND_IN" + str(value)
+
+            def process_result_value(self, value, dialect):
+                return value + "BIND_OUT"
+
+        tab = table('test', column('bvalue', MyTypeDec))
+        expr = tab.c.bvalue + 6
+        
+        self.assert_compile(
+            expr,
+            "test.bvalue || :bvalue_1",
+            use_default_dialect=True
+        )
+        
+        assert expr.type.__class__ is String
+        eq_(
+            testing.db.execute(select([expr.label('foo')])).scalar(),
+            "BIND_INfooBIND_IN6BIND_OUT"
+        )
+        
+        
     def test_bind_typing(self):
         from sqlalchemy.sql import column