]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Return self when Variant.coerce_compared_value would return impl
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Apr 2017 15:36:16 +0000 (11:36 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Apr 2017 17:47:42 +0000 (13:47 -0400)
Fixed regression released in 1.1.5 due to :ticket:`3859` where
adjustments to the "right-hand-side" evaluation of an expression
based on :class:`.Variant` to honor the underlying type's
"right-hand-side" rules caused the :class:`.Variant` type
to be inappropriately lost, in those cases when we *do* want the
left-hand side type to be transferred directly to the right hand side
so that bind-level rules can be applied to the expression's argument.

Change-Id: Ia54dbbb19398549d654b74668753c4152599d900
Fixes: #3952
doc/build/changelog/changelog_11.rst
lib/sqlalchemy/sql/type_api.py
test/sql/test_types.py

index 35e217b33416b794791d43af936f56a1fa6b3e41..e9774f9255c38d1cefb9705e8083b034445edbc6 100644 (file)
 .. changelog::
     :version: 1.1.9
 
+    .. change:: 3952
+        :tags: bug, sql
+        :versions: 1.2.0.b1
+        :tickets: 3952
+
+        Fixed regression released in 1.1.5 due to :ticket:`3859` where
+        adjustments to the "right-hand-side" evaluation of an expression
+        based on :class:`.Variant` to honor the underlying type's
+        "right-hand-side" rules caused the :class:`.Variant` type
+        to be inappropriately lost, in those cases when we *do* want the
+        left-hand side type to be transferred directly to the right hand side
+        so that bind-level rules can be applied to the expression's argument.
+
+
 .. changelog::
     :version: 1.1.8
     :released: March 31, 2017
index d537e49f023d31405c43d70cc12544e5453a491e..4b561a7058a8b4df1381e9787dc0fdcfb580cced 100644 (file)
@@ -1214,7 +1214,11 @@ class Variant(TypeDecorator):
         self.mapping = mapping
 
     def coerce_compared_value(self, operator, value):
-        return self.impl.coerce_compared_value(operator, value)
+        result = self.impl.coerce_compared_value(operator, value)
+        if result is self.impl:
+            return self
+        else:
+            return result
 
     def load_dialect_impl(self, dialect):
         if dialect.name in self.mapping:
index b417e696409d22c473e25e37ae1db6e08b13729b..f46ef21cd8b6571741a3d46f5970b9998e8eea6c 100644 (file)
@@ -1,5 +1,5 @@
 # coding: utf-8
-from sqlalchemy.testing import eq_, is_, assert_raises, \
+from sqlalchemy.testing import eq_, is_, is_not_, assert_raises, \
     assert_raises_message, expect_warnings
 import decimal
 import datetime
@@ -1006,6 +1006,55 @@ class VariantTest(fixtures.TestBase, AssertsCompiledSQL):
             'fooUTWO'
         )
 
+    def test_comparator_variant(self):
+        expr = column('x', self.variant) == "bar"
+        is_(
+            expr.right.type, self.variant
+        )
+
+    @testing.only_on("sqlite")
+    @testing.provide_metadata
+    def test_round_trip(self):
+        variant = self.UTypeOne().with_variant(
+            self.UTypeTwo(), 'sqlite')
+
+        t = Table('t', self.metadata,
+                Column('x', variant)
+        )
+        with testing.db.connect() as conn:
+            t.create(conn)
+
+            conn.execute(
+                t.insert(),
+                x='foo'
+            )
+
+            eq_(
+                conn.scalar(select([t.c.x]).where(t.c.x == 'foo')),
+                'fooUTWO'
+            )
+
+    @testing.only_on("sqlite")
+    @testing.provide_metadata
+    def test_round_trip_sqlite_datetime(self):
+        variant = DateTime().with_variant(
+            dialects.sqlite.DATETIME(truncate_microseconds=True), 'sqlite')
+
+        t = Table('t', self.metadata,
+                Column('x', variant)
+        )
+        with testing.db.connect() as conn:
+            t.create(conn)
+
+            conn.execute(
+                t.insert(),
+                x=datetime.datetime(2015, 4, 18, 10, 15, 17, 4839)
+            )
+
+            eq_(
+                conn.scalar(select([t.c.x]).where(t.c.x == datetime.datetime(2015, 4, 18, 10, 15, 17, 1059))),
+                datetime.datetime(2015, 4, 18, 10, 15, 17)
+            )
 
 class UnicodeTest(fixtures.TestBase):
 
@@ -2051,7 +2100,7 @@ class ExpressionTest(
             "BIND_INfooBIND_IN6BIND_OUT"
         )
 
-    def test_variant_righthand_coercion(self):
+    def test_variant_righthand_coercion_honors_wrapped(self):
         my_json_normal = JSON()
         my_json_variant = JSON().with_variant(String(), "sqlite")
 
@@ -2063,10 +2112,31 @@ class ExpressionTest(
         expr = tab.c.avalue['foo'] == 'bar'
 
         is_(expr.right.type._type_affinity, String)
+        is_not_(expr.right.type, my_json_normal)
 
         expr = tab.c.bvalue['foo'] == 'bar'
 
         is_(expr.right.type._type_affinity, String)
+        is_not_(expr.right.type, my_json_variant)
+
+    def test_variant_righthand_coercion_returns_self(self):
+        my_datetime_normal = DateTime()
+        my_datetime_variant = DateTime().with_variant(
+            dialects.sqlite.DATETIME(truncate_microseconds=False), "sqlite")
+
+        tab = table(
+            'test',
+            column('avalue', my_datetime_normal),
+            column('bvalue', my_datetime_variant)
+        )
+        expr = tab.c.avalue == datetime.datetime(2015, 10, 14, 15, 17, 18)
+
+        is_(expr.right.type._type_affinity, DateTime)
+        is_(expr.right.type, my_datetime_normal)
+
+        expr = tab.c.bvalue == datetime.datetime(2015, 10, 14, 15, 17, 18)
+
+        is_(expr.right.type, my_datetime_variant)
 
     def test_bind_typing(self):
         from sqlalchemy.sql import column