]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- The type/expression system now does a more complete job
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 11 Feb 2010 19:33:06 +0000 (19:33 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 11 Feb 2010 19:33:06 +0000 (19:33 +0000)
of determining the return type from an expression
as well as the adaptation of the Python operator into
a SQL operator, based on the full left/right/operator
of the given expression.  In particular
the date/time/interval system created for Postgresql
EXTRACT in [ticket:1647] has now been generalized into
the type system.   The previous behavior which often
occured of an expression "column + literal" forcing
the type of "literal" to be the same as that of "column"
will now usually not occur - the type of
"literal" is first derived from the Python type of the
literal, assuming standard native Python types + date
types, before falling back to that of the known type
on the other side of the expression.  Also part
of [ticket:1683].

CHANGES
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/types.py
test/dialect/test_postgresql.py
test/sql/test_select.py
test/sql/test_types.py

diff --git a/CHANGES b/CHANGES
index 17737885a0b3ab96471814a16351d92a0f28f006..ebbd60b38f68efb0f57aa467688a84090faf4e09 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -55,6 +55,23 @@ CHANGES
     
   - Restored the keys() method to ResultProxy.
   
+  - The type/expression system now does a more complete job
+    of determining the return type from an expression
+    as well as the adaptation of the Python operator into 
+    a SQL operator, based on the full left/right/operator
+    of the given expression.  In particular
+    the date/time/interval system created for Postgresql
+    EXTRACT in [ticket:1647] has now been generalized into
+    the type system.   The previous behavior which often
+    occured of an expression "column + literal" forcing
+    the type of "literal" to be the same as that of "column"
+    will now usually not occur - the type of 
+    "literal" is first derived from the Python type of the 
+    literal, assuming standard native Python types + date 
+    types, before falling back to that of the known type
+    on the other side of the expression.  Also part
+    of [ticket:1683].
+    
 - mysql
   - Fixed reflection bug whereby when COLLATE was present, 
     nullable flag and server defaults would not be reflected.
index cfbef69e89e3f1d1e695492926342e24e62733a8..7d4cbbbd808b503ff840c74757595c1ee9874fc8 100644 (file)
@@ -349,9 +349,16 @@ class PGCompiler(compiler.SQLCompiler):
 
     def visit_extract(self, extract, **kwargs):
         field = self.extract_map.get(extract.field, extract.field)
-        affinity = sql_util.determine_date_affinity(extract.expr)
-
-        casts = {sqltypes.Date:'date', sqltypes.DateTime:'timestamp', sqltypes.Interval:'interval', sqltypes.Time:'time'}
+        if extract.expr.type:
+            affinity = extract.expr.type._type_affinity
+        else:
+            affinity = None
+        
+        casts = {
+                    sqltypes.Date:'date', 
+                    sqltypes.DateTime:'timestamp', 
+                    sqltypes.Interval:'interval', sqltypes.Time:'time'
+                }
         cast = casts.get(affinity, None)
         if isinstance(extract.expr, sql.ColumnElement) and cast is not None:
             expr = extract.expr.op('::')(sql.literal_column(cast))
index 878b0d82653a44b475b4caf7779a36fa74c083f7..1ae706999ec6479a259143452fc3ea78889f1873 100644 (file)
@@ -407,14 +407,7 @@ def between(ctest, cleft, cright):
 
     """
     ctest = _literal_as_binds(ctest)
-    return _BinaryExpression(
-        ctest,
-        ClauseList(
-            _literal_as_binds(cleft, type_=ctest.type),
-            _literal_as_binds(cright, type_=ctest.type),
-            operator=operators.and_,
-            group=False),
-        operators.between_op)
+    return ctest.between(cleft, cright)
 
 
 def case(whens, value=None, else_=None):
@@ -1453,19 +1446,35 @@ class _CompareMixin(ColumnOperators):
             obj = self._check_literal(obj)
 
         if reverse:
-            return _BinaryExpression(obj, self, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs)
+            return _BinaryExpression(obj, 
+                            self, 
+                            op, 
+                            type_=sqltypes.BOOLEANTYPE, 
+                            negate=negate, modifiers=kwargs)
         else:
-            return _BinaryExpression(self, obj, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs)
+            return _BinaryExpression(self, 
+                            obj, 
+                            op, 
+                            type_=sqltypes.BOOLEANTYPE, 
+                            negate=negate, modifiers=kwargs)
 
     def __operate(self, op, obj, reverse=False):
         obj = self._check_literal(obj)
-
-        type_ = self._compare_type(obj)
-
+        
         if reverse:
-            return _BinaryExpression(obj, self, type_.adapt_operator(op), type_=type_)
+            left, right = obj, self
         else:
-            return _BinaryExpression(self, obj, type_.adapt_operator(op), type_=type_)
+            left, right = self, obj
+        
+        if left.type is None:
+            op, result_type = sqltypes.NULLTYPE._adapt_expression(op, right.type)
+        elif right.type is None:
+            op, result_type = left.type._adapt_expression(op, sqltypes.NULLTYPE)
+        else:    
+            op, result_type = left.type._adapt_expression(op, right.type)
+        
+        return _BinaryExpression(left, right, op, type_=result_type)
+        
 
     # a mapping of operators with the method they use, along with their negated
     # operator for comparison operators
@@ -1643,7 +1652,7 @@ class _CompareMixin(ColumnOperators):
         return lambda other: self.__operate(operator, other)
 
     def _bind_param(self, obj):
-        return _BindParamClause(None, obj, type_=self.type, unique=True)
+        return _BindParamClause(None, obj, _fallback_type=self.type, unique=True)
 
     def _check_literal(self, other):
         if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType):
@@ -1658,14 +1667,6 @@ class _CompareMixin(ColumnOperators):
         else:
             return other
 
-    def _compare_type(self, obj):
-        """Allow subclasses to override the type used in constructing
-        :class:`_BinaryExpression` objects.
-
-        Default return value is the type of the given object.
-
-        """
-        return obj.type
 
 class ColumnElement(ClauseElement, _CompareMixin):
     """Represent an element that is usable within the "column clause" portion of a ``SELECT`` statement.
@@ -2105,7 +2106,9 @@ class _BindParamClause(ColumnElement):
     __visit_name__ = 'bindparam'
     quote = None
 
-    def __init__(self, key, value, type_=None, unique=False, isoutparam=False, required=False):
+    def __init__(self, key, value, type_=None, unique=False, 
+                            isoutparam=False, required=False, 
+                            _fallback_type=None):
         """Construct a _BindParamClause.
 
         key
@@ -2151,12 +2154,12 @@ class _BindParamClause(ColumnElement):
         self.required = required
         
         if type_ is None:
-            self.type = sqltypes.type_map.get(type(value), sqltypes.NullType)()
+            self.type = sqltypes.type_map.get(type(value), _fallback_type or sqltypes.NULLTYPE)
         elif isinstance(type_, type):
             self.type = type_()
         else:
             self.type = type_
-
+            
     def _clone(self):
         c = ClauseElement._clone(self)
         if self.unique:
@@ -2171,12 +2174,6 @@ class _BindParamClause(ColumnElement):
     def bind_processor(self, dialect):
         return self.type.dialect_impl(dialect).bind_processor(dialect)
 
-    def _compare_type(self, obj):
-        if not isinstance(self.type, sqltypes.NullType):
-            return self.type
-        else:
-            return obj.type
-
     def compare(self, other, **kw):
         """Compare this :class:`_BindParamClause` to the given clause."""
         
@@ -2342,7 +2339,14 @@ class ClauseList(ClauseElement):
             self.clauses = [
                 _literal_as_text(clause)
                 for clause in clauses if clause is not None]
-
+    
+    @util.memoized_property
+    def type(self):
+        if self.clauses:
+            return self.clauses[0].type
+        else:
+            return sqltypes.NULLTYPE
+            
     def __iter__(self):
         return iter(self.clauses)
 
@@ -2419,7 +2423,7 @@ class _Tuple(ClauseList, ColumnElement):
 
     def _bind_param(self, obj):
         return _Tuple(*[
-            _BindParamClause(None, o, type_=self.type, unique=True)
+            _BindParamClause(None, o, _fallback_type=self.type, unique=True)
             for o in obj
         ]).self_group()
     
@@ -2518,11 +2522,8 @@ class FunctionElement(ColumnElement, FromClause):
     def execute(self):
         return select([self]).execute()
 
-    def _compare_type(self, obj):
-        return self.type
-
     def _bind_param(self, obj):
-        return _BindParamClause(None, obj, type_=self.type, unique=True)
+        return _BindParamClause(None, obj, _fallback_type=self.type, unique=True)
 
     
 class Function(FunctionElement):
@@ -2539,7 +2540,7 @@ class Function(FunctionElement):
         FunctionElement.__init__(self, *clauses, **kw)
 
     def _bind_param(self, obj):
-        return _BindParamClause(self.name, obj, type_=self.type, unique=True)
+        return _BindParamClause(self.name, obj, _fallback_type=self.type, unique=True)
 
 
 class _Cast(ColumnElement):
@@ -2698,7 +2699,7 @@ class _BinaryExpression(ColumnElement):
                 self.right,
                 self.negate,
                 negate=self.operator,
-                type_=self.type,
+                type_=sqltypes.BOOLEANTYPE,
                 modifiers=self.modifiers)
         else:
             return super(_BinaryExpression, self)._negate()
@@ -3149,7 +3150,7 @@ class ColumnClause(_Immutable, ColumnElement):
             return []
 
     def _bind_param(self, obj):
-        return _BindParamClause(self.name, obj, type_=self.type, unique=True)
+        return _BindParamClause(self.name, obj, _fallback_type=self.type, unique=True)
 
     def _make_proxy(self, selectable, name=None, attach=True):
         # propagate the "is_literal" flag only if we are keeping our name,
@@ -3166,9 +3167,6 @@ class ColumnClause(_Immutable, ColumnElement):
             selectable.columns[c.name] = c
         return c
 
-    def _compare_type(self, obj):
-        return self.type
-
 class TableClause(_Immutable, FromClause):
     """Represents a "table" construct.
 
index 821b3a3d129e25cf9ef28c3f5fedb757a7429651..43673eaec5601aebb182d6da2a37e033be9f98e0 100644 (file)
@@ -46,92 +46,6 @@ def find_join_source(clauses, join_to):
     else:
         return None, None
 
-_date_affinities = None
-def determine_date_affinity(expr):
-    """Given an expression, determine if it returns 'interval', 'date', or 'datetime'.
-    
-    the PG dialect uses this to generate the extract() function.
-    
-    It's less than ideal since it basically needs to duplicate PG's 
-    date arithmetic rules.   
-    
-    Rules are based on http://www.postgresql.org/docs/current/static/functions-datetime.html.
-    
-    Returns None if operators other than + or - are detected as well as types
-    outside of those above.
-    
-    """
-    
-    global _date_affinities
-    if _date_affinities is None:
-        Date, DateTime, Integer, \
-            Numeric, Interval, Time = \
-                                    sqltypes.Date, sqltypes.DateTime,\
-                                    sqltypes.Integer, sqltypes.Numeric,\
-                                    sqltypes.Interval, sqltypes.Time
-
-        _date_affinities = {
-            operators.add:{
-                (Date, Integer):Date,
-                (Date, Interval):DateTime,
-                (Date, Time):DateTime,
-                (Interval, Interval):Interval,
-                (DateTime, Interval):DateTime,
-                (Interval, Time):Time,
-            },
-            operators.sub:{
-                (Date, Integer):Date,
-                (Date, Interval):DateTime,
-                (Time, Time):Interval,
-                (Time, Interval):Time,
-                (DateTime, Interval):DateTime,
-                (Interval, Interval):Interval,
-                (DateTime, DateTime):Interval,
-            },
-            operators.mul:{
-                (Integer, Interval):Interval,
-                (Interval, Numeric):Interval,
-            },
-            operators.div: {
-                (Interval, Numeric):Interval
-            }
-        }
-    
-    if isinstance(expr, expression._BinaryExpression):
-        if expr.operator not in _date_affinities:
-            return None
-            
-        left_affin, right_affin = \
-            determine_date_affinity(expr.left), \
-            determine_date_affinity(expr.right)
-
-        if left_affin is None or right_affin is None:
-            return None
-        
-        if operators.is_commutative(expr.operator):
-            key = tuple(sorted([left_affin, right_affin], key=lambda cls:cls.__name__))
-        else:
-            key = (left_affin, right_affin)
-        
-        lookup = _date_affinities[expr.operator]
-        return lookup.get(key, None)
-
-    # work around the fact that expressions put the wrong type
-    # on generated bind params when its "datetime + timedelta"
-    # and similar
-    if isinstance(expr, expression._BindParamClause):
-        type_ = sqltypes.type_map.get(type(expr.value), sqltypes.NullType)()
-    else:
-        type_ = expr.type
-
-    affinities = set([sqltypes.Date, sqltypes.DateTime, 
-                    sqltypes.Interval, sqltypes.Time, sqltypes.Integer])
-        
-    if type_ is not None and type_._type_affinity in affinities:
-        return type_._type_affinity
-    else:
-        return None
-    
     
     
 def find_tables(clause, check_columns=False, 
index f4d94c918088bd3a87a8cde472ccb93b8f6df62c..465454df95bc2d60d22580b2fbf6463ba1d4abaa 100644 (file)
@@ -26,12 +26,13 @@ from decimal import Decimal as _python_Decimal
 import codecs
 
 from sqlalchemy import exc, schema
-from sqlalchemy.sql import expression
+from sqlalchemy.sql import expression, operators
 import sys
 schema.types = expression.sqltypes =sys.modules['sqlalchemy.types']
 from sqlalchemy.util import pickle
 from sqlalchemy.sql.visitors import Visitable
 from sqlalchemy import util
+
 NoneType = type(None)
 if util.jython:
     import array
@@ -95,22 +96,23 @@ class AbstractType(Visitable):
         """
         return None
 
-    def adapt_operator(self, op):
-        """Given an operator from the sqlalchemy.sql.operators package,
-        translate it to a new operator based on the semantics of this type.
-
-        By default, returns the operator unchanged.
-
+    def _adapt_expression(self, op, othertype):
+        """evaluate the return type of <self> <op> <othertype>,
+        and apply any adaptations to the given operator.
+        
         """
-        return op
+        return op, self
         
     @util.memoized_property
     def _type_affinity(self):
         """Return a rudimental 'affinity' value expressing the general class of type."""
-        
-        for i, t in enumerate(self.__class__.__mro__):
+
+        typ = None
+        for t in self.__class__.__mro__:
             if t is TypeEngine or t is UserDefinedType:
-                return self.__class__.__mro__[i - 1]
+                return typ
+            elif issubclass(t, TypeEngine):
+                typ = t
         else:
             return self.__class__
         
@@ -206,6 +208,23 @@ class UserDefinedType(TypeEngine):
     """
     __visit_name__ = "user_defined"
 
+    def _adapt_expression(self, op, othertype):
+        """evaluate the return type of <self> <op> <othertype>,
+        and apply any adaptations to the given operator.
+        
+        """
+        return self.adapt_operator(op), self
+
+    def adapt_operator(self, op):
+        """A hook which allows the given operator to be adapted
+        to something new. 
+        
+        See also UserDefinedType._adapt_expression(), an as-yet-
+        semi-public method with greater capability in this regard.
+        
+        """
+        return op
+
 class TypeDecorator(AbstractType):
     """Allows the creation of types which add additional functionality
     to an existing type.
@@ -429,18 +448,41 @@ class NullType(TypeEngine):
     """
     __visit_name__ = 'null'
 
+    def _adapt_expression(self, op, othertype):
+        if othertype is NullType or not operators.is_commutative(op):
+            return op, self
+        else:
+            return othertype._adapt_expression(op, self)
+
 NullTypeEngine = NullType
 
 class Concatenable(object):
     """A mixin that marks a type as supporting 'concatenation', typically strings."""
 
-    def adapt_operator(self, op):
-        """Converts an add operator to concat."""
-        from sqlalchemy.sql import operators
-        if op is operators.add:
-            return operators.concat_op
+    def _adapt_expression(self, op, othertype):
+        if op is operators.add and isinstance(othertype, (Concatenable, NullType)):
+            return operators.concat_op, self
         else:
-            return op
+            return op, self
+
+class _DateAffinity(object):
+    """Mixin date/time specific expression adaptations.
+    
+    Rules are implemented within Date,Time,Interval,DateTime, Numeric, Integer.
+    Based on http://www.postgresql.org/docs/current/static/functions-datetime.html.
+    
+    """
+    
+    @property
+    def _expression_adaptations(self):
+        raise NotImplementedError()
+
+    _blank_dict = util.frozendict()
+    def _adapt_expression(self, op, othertype):
+        othertype = othertype._type_affinity
+        return op, \
+                self._expression_adaptations.get(op, self._blank_dict).\
+                get(othertype, NULLTYPE)
 
 class String(Concatenable, TypeEngine):
     """The base for all string and character types.
@@ -673,14 +715,24 @@ class UnicodeText(Text):
         super(UnicodeText, self).__init__(length=length, **kwargs)
 
 
-class Integer(TypeEngine):
+class Integer(_DateAffinity, TypeEngine):
     """A type for ``int`` integers."""
 
     __visit_name__ = 'integer'
 
     def get_dbapi_type(self, dbapi):
         return dbapi.NUMBER
-
+    
+    @util.memoized_property
+    def _expression_adaptations(self):
+        return {
+            operators.add:{
+                Date:Date,
+            },
+            operators.mul:{
+                Interval:Interval
+            },
+        }
 
 class SmallInteger(Integer):
     """A type for smaller ``int`` integers.
@@ -702,7 +754,7 @@ class BigInteger(Integer):
 
     __visit_name__ = 'big_integer'
 
-class Numeric(TypeEngine):
+class Numeric(_DateAffinity, TypeEngine):
     """A type for fixed precision numbers.
 
     Typically generates DECIMAL or NUMERIC.  Returns
@@ -776,6 +828,14 @@ class Numeric(TypeEngine):
         else:
             return None
 
+    @util.memoized_property
+    def _expression_adaptations(self):
+        return {
+            operators.mul:{
+                Interval:Interval
+            },
+        }
+
 
 class Float(Numeric):
     """A type for ``float`` numbers.  
@@ -804,7 +864,7 @@ class Float(Numeric):
         return impltype(precision=self.precision, asdecimal=self.asdecimal)
 
 
-class DateTime(TypeEngine):
+class DateTime(_DateAffinity, TypeEngine):
     """A type for ``datetime.datetime()`` objects.
 
     Date and time types return objects from the Python ``datetime``
@@ -826,8 +886,20 @@ class DateTime(TypeEngine):
     def get_dbapi_type(self, dbapi):
         return dbapi.DATETIME
 
+    @util.memoized_property
+    def _expression_adaptations(self):
+        return {
+            operators.add:{
+                Interval:DateTime,
+            },
+            operators.sub:{
+                Interval:DateTime,
+                DateTime:Interval,
+            },
+        }
+        
 
-class Date(TypeEngine):
+class Date(_DateAffinity,TypeEngine):
     """A type for ``datetime.date()`` objects."""
 
     __visit_name__ = 'date'
@@ -835,8 +907,32 @@ class Date(TypeEngine):
     def get_dbapi_type(self, dbapi):
         return dbapi.DATETIME
 
-
-class Time(TypeEngine):
+    @util.memoized_property
+    def _expression_adaptations(self):
+        return {
+            operators.add:{
+                Integer:Date,
+                Interval:DateTime,
+                Time:DateTime,
+            },
+            operators.sub:{
+                # date - integer = date
+                Integer:Date,
+                
+                # date - date = integer.
+                Date:Integer,
+
+                Interval:DateTime,
+                
+                # date - datetime = interval,
+                # this one is not in the PG docs 
+                # but works
+                DateTime:Interval,
+            },
+        }
+
+
+class Time(_DateAffinity,TypeEngine):
     """A type for ``datetime.time()`` objects."""
 
     __visit_name__ = 'time'
@@ -850,6 +946,20 @@ class Time(TypeEngine):
     def get_dbapi_type(self, dbapi):
         return dbapi.DATETIME
 
+    @util.memoized_property
+    def _expression_adaptations(self):
+        return {
+            operators.add:{
+                Date:DateTime,
+                Interval:Time
+            },
+            operators.sub:{
+                Time:Interval,
+                Interval:Time,
+            },
+        }
+
+
 class _Binary(TypeEngine):
     """Define base behavior for binary types."""
 
@@ -1245,7 +1355,7 @@ class Boolean(TypeEngine, SchemaType):
                 return value and True or False
             return process
 
-class Interval(TypeDecorator):
+class Interval(_DateAffinity, TypeDecorator):
     """A type for ``datetime.timedelta()`` objects.
 
     The Interval type deals with ``datetime.timedelta`` objects.  In
@@ -1319,10 +1429,31 @@ class Interval(TypeDecorator):
                 return value - epoch
         return process
 
+    @util.memoized_property
+    def _expression_adaptations(self):
+        return {
+            operators.add:{
+                Date:DateTime,
+                Interval:Interval,
+                DateTime:DateTime,
+                Time:Time,
+            },
+            operators.sub:{
+                Interval:Interval
+            },
+            operators.mul:{
+                Numeric:Interval
+            },
+            operators.div: {
+                Numeric:Interval
+            }
+        }
+
     @property
     def _type_affinity(self):
         return Interval
 
+
 class FLOAT(Float):
     """The SQL FLOAT type."""
 
@@ -1440,22 +1571,23 @@ class BOOLEAN(Boolean):
     __visit_name__ = 'BOOLEAN'
 
 NULLTYPE = NullType()
+BOOLEANTYPE = Boolean()
 
 # using VARCHAR/NCHAR so that we dont get the genericized "String"
 # type which usually resolves to TEXT/CLOB
 type_map = {
-    str: String,
+    str: String(),
     # Py2K
-    unicode : String,
+    unicode : String(),
     # end Py2K
-    int : Integer,
-    float : Numeric,
-    bool: Boolean,
-    _python_Decimal : Numeric,
-    dt.date : Date,
-    dt.datetime : DateTime,
-    dt.time : Time,
-    dt.timedelta : Interval,
-    NoneType: NullType
+    int : Integer(),
+    float : Numeric(),
+    bool: BOOLEANTYPE,
+    _python_Decimal : Numeric(),
+    dt.date : Date(),
+    dt.datetime : DateTime(),
+    dt.time : Time(),
+    dt.timedelta : Interval(),
+    NoneType: NULLTYPE
 }
 
index 952f633ae3340e1220f414b73e6e0a94ef8a491a..fbbc394c9dda775818c2e57f6c736fe760c8c2e7 100644 (file)
@@ -113,6 +113,7 @@ class CompileTest(TestBase, AssertsCompiledSQL):
 
         for field in 'year', 'month', 'day', 'epoch', 'hour':
             for expr, compiled_expr in [
+
                 ( t.c.col1, "t.col1 :: timestamp" ),
                 ( t.c.col2, "t.col2 :: date" ),
                 ( t.c.col3, "t.col3 :: time" ),
index 766ce8e9b789af479df037ca8f2dccd9c7effd63..657509d6591697130511efccf473ddbc61dd9e5b 100644 (file)
@@ -645,7 +645,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
             (table1.c.myid.match('somstr'), "CONTAINS (mytable.myid, :myid_1)", oracle.dialect()),            
         ]:
             self.assert_compile(expr, check, dialect=dialect)
-        
+
     def test_composed_string_comparators(self):
         self.assert_compile(
             table1.c.name.contains('jo'), "mytable.name LIKE '%%' || :name_1 || '%%'" , checkparams = {'name_1': u'jo'},
@@ -1377,7 +1377,7 @@ EXISTS (select yay from foo where boo = lar)",
         
         assert [str(c) for c in s.c] == ["id", "hoho"]
 
-        
+    
     @testing.emits_warning('.*empty sequence.*')
     def test_in(self):
         self.assert_compile(table1.c.myid.in_(['a']),
@@ -1547,7 +1547,7 @@ EXISTS (select yay from foo where boo = lar)",
             "SELECT dt.date FROM dt WHERE dt.date BETWEEN :date_1 AND :date_2", checkparams={'date_1':datetime.date(2006,6,1), 'date_2':datetime.date(2006,6,5)})
 
         self.assert_compile(table.select(sql.between(table.c.date, datetime.date(2006,6,1), datetime.date(2006,6,5))),
-            "SELECT dt.date FROM dt WHERE dt.date BETWEEN :param_1 AND :param_2", checkparams={'param_1':datetime.date(2006,6,1), 'param_2':datetime.date(2006,6,5)})
+            "SELECT dt.date FROM dt WHERE dt.date BETWEEN :date_1 AND :date_2", checkparams={'date_1':datetime.date(2006,6,1), 'date_2':datetime.date(2006,6,5)})
 
     def test_operator_precedence(self):
         table = Table('op', metadata,
index 0dcf4470fb485473091b8430583c05c8585f6d45..fd957c6e517b4b6887bed58c9afb5957f4f57857 100644 (file)
@@ -57,6 +57,14 @@ class AdaptTest(TestBase):
             
 class TypeAffinityTest(TestBase):
     def test_type_affinity(self):
+        for type_, affin in [
+            (String(), String),
+            (VARCHAR(), String),
+            (Date(), Date),
+            (LargeBinary(), types._Binary)
+        ]:
+            eq_(type_._type_affinity, affin)
+            
         for t1, t2, comp in [
             (Integer(), SmallInteger(), True),
             (Integer(), String(), False),
@@ -536,7 +544,7 @@ class BinaryTest(TestBase, AssertsExecutionResults):
 class ExpressionTest(TestBase, AssertsExecutionResults):
     @classmethod
     def setup_class(cls):
-        global test_table, meta
+        global test_table, meta, MyCustomType
 
         class MyCustomType(types.UserDefinedType):
             def get_col_spec(self):
@@ -570,7 +578,10 @@ class ExpressionTest(TestBase, AssertsExecutionResults):
     def test_control(self):
         assert testing.db.execute("select avalue from test").scalar() == 250
 
-        assert test_table.select().execute().fetchall() == [(1, 'somedata', datetime.date(2007, 10, 15), 25)]
+        eq_(
+            test_table.select().execute().fetchall(),
+            [(1, 'somedata', datetime.date(2007, 10, 15), 25)]
+        )
 
     def test_bind_adapt(self):
         expr = test_table.c.atimestamp == bindparam("thedate")
@@ -597,6 +608,12 @@ class ExpressionTest(TestBase, AssertsExecutionResults):
         expr = test_table.c.avalue + 40
         assert expr.type.__class__ is test_table.c.avalue.type.__class__
 
+        # value here is calculated as (250 - 40) / 10 = 21
+        # because "40" is an integer, not an "avalue"
+        assert testing.db.execute(select([expr.label('foo')])).scalar() == 21
+
+        expr = test_table.c.avalue + literal(40, type_=MyCustomType)
+        
         # + operator converted to -
         # value is calculated as: (250 - (40 * 10)) / 10 == -15
         assert testing.db.execute(select([expr.label('foo')])).scalar() == -15
@@ -604,6 +621,44 @@ class ExpressionTest(TestBase, AssertsExecutionResults):
         # this one relies upon anonymous labeling to assemble result
         # processing rules on the column.
         assert testing.db.execute(select([expr])).scalar() == -15
+    
+    def test_bind_typing(self):
+        from sqlalchemy.sql import column
+        
+        class MyFoobarType(types.UserDefinedType):
+            pass
+        
+        class Foo(object):
+            pass
+        
+        # unknown type + integer, right hand bind
+        # is an Integer
+        expr = column("foo", MyFoobarType) + 5
+        assert expr.right.type._type_affinity is types.Integer
+
+        # unknown type + unknown, right hand bind
+        # coerces to the left
+        expr = column("foo", MyFoobarType) + Foo()
+        assert expr.right.type._type_affinity is MyFoobarType
+        
+        # including for non-commutative ops
+        expr = column("foo", MyFoobarType) - Foo()
+        assert expr.right.type._type_affinity is MyFoobarType
+
+        expr = column("foo", MyFoobarType) - datetime.date(2010, 8, 25)
+        assert expr.right.type._type_affinity is types.Date
+        
+    def test_date_coercion(self):
+        from sqlalchemy.sql import column
+        
+        expr = column('bar', types.NULLTYPE) - column('foo', types.TIMESTAMP)
+        eq_(expr.type._type_affinity, types.NullType)
+        
+        expr = func.sysdate() - column('foo', types.TIMESTAMP)
+        eq_(expr.type._type_affinity, types.Interval)
+
+        expr = func.current_date() - column('foo', types.TIMESTAMP)
+        eq_(expr.type._type_affinity, types.Interval)
         
     def test_distinct(self):
         s = select([distinct(test_table.c.avalue)])