]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Return given type when it matches the adaptation
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 26 Jun 2017 16:44:15 +0000 (12:44 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 26 Jun 2017 16:46:53 +0000 (12:46 -0400)
The rules for type coercion between :class:`.Numeric`, :class:`.Integer`,
and date-related types now include additional logic that will attempt
to preserve the settings of the incoming type on the "resolved" type.
Currently the target for this is the ``asdecimal`` flag, so that
a math operation between :class:`.Numeric` or :class:`.Float` and
:class:`.Integer` will preserve the "asdecimal" flag as well as
if the type should be the :class:`.Float` subclass.

Change-Id: Idfaba17220d6db21ca1ca4dcb4c19834cd397817
Fixes: #4018
doc/build/changelog/changelog_12.rst
doc/build/changelog/migration_12.rst
lib/sqlalchemy/sql/sqltypes.py
test/sql/test_types.py

index 5dc83da2d50898e5e3dad32eeb1d943dcdb745ec..043f5f1ac287f9cf33fba39ee9562d2639b9e87c 100644 (file)
 .. changelog::
     :version: 1.2.0b1
 
+    .. change:: 4018
+        :tags: bug, sql
+        :tickets: 4018
+
+        The rules for type coercion between :class:`.Numeric`, :class:`.Integer`,
+        and date-related types now include additional logic that will attempt
+        to preserve the settings of the incoming type on the "resolved" type.
+        Currently the target for this is the ``asdecimal`` flag, so that
+        a math operation between :class:`.Numeric` or :class:`.Float` and
+        :class:`.Integer` will preserve the "asdecimal" flag as well as
+        if the type should be the :class:`.Float` subclass.
+
+        .. seealso::
+
+            :ref:`change_floats_12`
+
     .. change:: 4017
         :tags: bug, sql
         :tickets: 4017
index add12a50c304d73ad3c3ea30fe63b9c2aced95ec..b3f7d726175e584a5af2e092c34163c583955b03 100644 (file)
@@ -782,10 +782,25 @@ if the application is working with plain floats.
   meant the result type would coerce to ``Decimal()``.  In particular,
   this would emit a confusing warning on SQLite::
 
-        float_value = connection.scalar(
-            select([literal(4.56)])   # the "BindParameter" will now be
-                                      # Float, not Numeric(asdecimal=True)
-        )
+
+    float_value = connection.scalar(
+        select([literal(4.56)])   # the "BindParameter" will now be
+                                  # Float, not Numeric(asdecimal=True)
+    )
+
+* Math operations between :class:`.Numeric`, :class:`.Float`, and
+  :class:`.Integer` will now preserve the :class:`.Numeric` or :class:`.Float`
+  type in the resulting expression's type, including the ``asdecimal`` flag
+  as well as if the type should be :class:`.Float`::
+
+    # asdecimal flag is maintained
+    expr = column('a', Integer) * column('b', Numeric(asdecimal=False))
+    assert expr.type.asdecimal == False
+
+    # Float subclass of Numeric is maintained
+    expr = column('a', Integer) * column('b', Float())
+    assert isinstance(expr.type, Float)
+
 
 :ticket:`4017`
 
index 06b5e5c19c02b545d84c5d3cd61757bc88cd1e57..5b53f390eddc1b2dcb1fc9c1e0d9cae6043584f7 100644 (file)
@@ -31,13 +31,12 @@ if util.jython:
     import array
 
 
-class _DateAffinity(object):
+class _LookupExpressionAdapter(object):
 
-    """Mixin date/time specific expression adaptations.
+    """Mixin expression adaptations based on lookup tables.
 
-    Rules are implemented within Date,Time,Interval,DateTime, Numeric,
-    Integer. Based on http://www.postgresql.org/docs/current/static
-    /functions-datetime.html.
+    These rules are currenly used by the numeric, integer and date types
+    which have detailed cross-expression coercion rules.
 
     """
 
@@ -50,12 +49,15 @@ class _DateAffinity(object):
 
         def _adapt_expression(self, op, other_comparator):
             othertype = other_comparator.type._type_affinity
-            return (
-                op, to_instance(
-                    self.type._expression_adaptations.
-                    get(op, self._blank_dict).
-                    get(othertype, NULLTYPE))
-            )
+            lookup = self.type._expression_adaptations.get(
+                op, self._blank_dict).get(
+                othertype, NULLTYPE)
+            if lookup is othertype:
+                return (op, other_comparator.type)
+            elif lookup is self.type._type_affinity:
+                return (op, self.type)
+            else:
+                return (op, to_instance(lookup))
     comparator_factory = Comparator
 
 
@@ -384,7 +386,7 @@ class UnicodeText(Text):
         super(UnicodeText, self).__init__(length=length, **kwargs)
 
 
-class Integer(_DateAffinity, TypeEngine):
+class Integer(_LookupExpressionAdapter, TypeEngine):
 
     """A type for ``int`` integers."""
 
@@ -456,7 +458,7 @@ class BigInteger(Integer):
     __visit_name__ = 'big_integer'
 
 
-class Numeric(_DateAffinity, TypeEngine):
+class Numeric(_LookupExpressionAdapter, TypeEngine):
 
     """A type for fixed precision numbers, such as ``NUMERIC`` or ``DECIMAL``.
 
@@ -703,29 +705,8 @@ class Float(Numeric):
         else:
             return None
 
-    @util.memoized_property
-    def _expression_adaptations(self):
-        return {
-            operators.mul: {
-                Interval: Interval,
-                Numeric: self.__class__,
-            },
-            operators.div: {
-                Numeric: self.__class__,
-            },
-            operators.truediv: {
-                Numeric: self.__class__,
-            },
-            operators.add: {
-                Numeric: self.__class__,
-            },
-            operators.sub: {
-                Numeric: self.__class__,
-            }
-        }
-
 
-class DateTime(_DateAffinity, TypeEngine):
+class DateTime(_LookupExpressionAdapter, TypeEngine):
 
     """A type for ``datetime.datetime()`` objects.
 
@@ -770,6 +751,10 @@ class DateTime(_DateAffinity, TypeEngine):
 
     @util.memoized_property
     def _expression_adaptations(self):
+
+        # Based on http://www.postgresql.org/docs/current/\
+        # static/functions-datetime.html.
+
         return {
             operators.add: {
                 Interval: self.__class__,
@@ -781,7 +766,7 @@ class DateTime(_DateAffinity, TypeEngine):
         }
 
 
-class Date(_DateAffinity, TypeEngine):
+class Date(_LookupExpressionAdapter, TypeEngine):
 
     """A type for ``datetime.date()`` objects."""
 
@@ -796,6 +781,9 @@ class Date(_DateAffinity, TypeEngine):
 
     @util.memoized_property
     def _expression_adaptations(self):
+        # Based on http://www.postgresql.org/docs/current/\
+        # static/functions-datetime.html.
+
         return {
             operators.add: {
                 Integer: self.__class__,
@@ -819,7 +807,7 @@ class Date(_DateAffinity, TypeEngine):
         }
 
 
-class Time(_DateAffinity, TypeEngine):
+class Time(_LookupExpressionAdapter, TypeEngine):
 
     """A type for ``datetime.time()`` objects."""
 
@@ -837,6 +825,9 @@ class Time(_DateAffinity, TypeEngine):
 
     @util.memoized_property
     def _expression_adaptations(self):
+        # Based on http://www.postgresql.org/docs/current/\
+        # static/functions-datetime.html.
+
         return {
             operators.add: {
                 Date: DateTime,
@@ -1627,7 +1618,7 @@ class Boolean(TypeEngine, SchemaType):
             return processors.int_to_boolean
 
 
-class Interval(_DateAffinity, TypeDecorator):
+class Interval(_LookupExpressionAdapter, TypeDecorator):
 
     """A type for ``datetime.timedelta()`` objects.
 
@@ -1719,6 +1710,9 @@ class Interval(_DateAffinity, TypeDecorator):
 
     @util.memoized_property
     def _expression_adaptations(self):
+        # Based on http://www.postgresql.org/docs/current/\
+        # static/functions-datetime.html.
+
         return {
             operators.add: {
                 Date: DateTime,
index 9107adaca0b84d04157fe6714314eab21c57c51a..404d42c7aaa81c98fcb73fbbd527eb8c34361abc 100644 (file)
@@ -29,6 +29,8 @@ from sqlalchemy.testing.util import picklers
 from sqlalchemy.testing.util import round_decimal
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import mock
+from sqlalchemy.sql import column
+import operator
 
 
 class AdaptTest(fixtures.TestBase):
@@ -2203,8 +2205,6 @@ class ExpressionTest(
         eq_(expr.type._type_affinity, types.Interval)
 
     def test_numerics_coercion(self):
-        from sqlalchemy.sql import column
-        import operator
 
         for op in (operator.add, operator.mul, operator.truediv, operator.sub):
             for other in (Numeric(10, 2), Integer):
@@ -2219,6 +2219,28 @@ class ExpressionTest(
                 )
                 assert isinstance(expr.type, types.Numeric)
 
+    def test_asdecimal_int_to_numeric(self):
+        expr = column('a', Integer) * column('b', Numeric(asdecimal=False))
+        is_(expr.type.asdecimal, False)
+
+        expr = column('a', Integer) * column('b', Numeric())
+        is_(expr.type.asdecimal, True)
+
+        expr = column('a', Integer) * column('b', Float())
+        is_(expr.type.asdecimal, False)
+        assert isinstance(expr.type, Float)
+
+    def test_asdecimal_numeric_to_int(self):
+        expr = column('a', Numeric(asdecimal=False)) * column('b', Integer)
+        is_(expr.type.asdecimal, False)
+
+        expr = column('a', Numeric()) * column('b', Integer)
+        is_(expr.type.asdecimal, True)
+
+        expr = column('a', Float()) * column('b', Integer)
+        is_(expr.type.asdecimal, False)
+        assert isinstance(expr.type, Float)
+
     def test_null_comparison(self):
         eq_(
             str(column('a', types.NullType()) + column('b', types.NullType())),