]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- add a type_coerce() step within Enum, Boolean to the CHECK constraint,
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 20 Oct 2013 20:25:46 +0000 (16:25 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 20 Oct 2013 20:25:46 +0000 (16:25 -0400)
so that the custom type isn't exposed to an operation that is against the
"impl" type's constraint, [ticket:2842]
- this change showed up as some recursion overflow in pickling with labels,
add a __reduce__() there....pickling of expressions is less and less something
that's very viable...

doc/build/changelog/changelog_08.rst
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/sqltypes.py
test/sql/test_types.py

index 126f3a3b8e1234c4f85c6dd543b7e40faf85cb63..727931f9aa629b521d4f4b924dbdcc83348a3e3e 100644 (file)
 .. changelog::
     :version: 0.8.3
 
+    .. change::
+        :tags: bug, sql
+        :tickets: 2842
+        :versions: 0.9.0
+
+        The :class:`.Enum` and :class:`.Boolean` types now bypass
+        any custom (e.g. TypeDecorator) type in use when producing the
+        CHECK constraint for the "non native" type.  This so that the custom type
+        isn't involved in the expression within the CHECK, since this
+        expression is against the "impl" value and not the "decorated" value.
+
     .. change::
         :tags: bug, postgresql
         :tickets: 2844
         Fixed bug in default compiler plus those of postgresql, mysql, and
         mssql to ensure that any literal SQL expression values are
         rendered directly as literals, instead of as bound parameters,
-        within a CREATE INDEX statement.
+        within a CREATE INDEX statement.  This also changes the rendering
+        scheme for other DDL such as constraints.
 
     .. change::
         :tags: bug, sql
index 2688ef103df55952cd2d0f555ecb3fd4e5d2aa41..f70496418f0f768bf7766782141a383572be1fc2 100644 (file)
@@ -1842,6 +1842,9 @@ class Label(ColumnElement):
         self._type = type_
         self._proxies = [element]
 
+    def __reduce__(self):
+        return self.__class__, (self.name, self._element, self._type)
+
     @util.memoized_property
     def _order_by_label_element(self):
         return self
index db0ad248c3c91aef5840a8d7e928ed8fa4ee7515..1d7dacb915820ce210b18273114de2abf86bcc44 100644 (file)
@@ -12,7 +12,7 @@ import datetime as dt
 import codecs
 
 from .type_api import TypeEngine, TypeDecorator, to_instance
-from .elements import quoted_name
+from .elements import quoted_name, type_coerce
 from .default_comparator import _DefaultColumnComparator
 from .. import exc, util, processors
 from .base import _bind_or_error, SchemaEventTarget
@@ -1059,7 +1059,7 @@ class Enum(String, SchemaType):
             SchemaType._set_table(self, column, table)
 
         e = schema.CheckConstraint(
-                        column.in_(self.enums),
+                        type_coerce(column, self).in_(self.enums),
                         name=self.name,
                         _create_rule=util.portable_instancemethod(
                                         self._should_create_constraint)
@@ -1196,7 +1196,7 @@ class Boolean(TypeEngine, SchemaType):
             return
 
         e = schema.CheckConstraint(
-                        column.in_([0, 1]),
+                        type_coerce(column, self).in_([0, 1]),
                         name=self.name,
                         _create_rule=util.portable_instancemethod(
                                     self._should_create_constraint)
index 2a22224a2bb719bbaabcf27959cda92161f9d744..d122aef6a8e1ae21a54a82d22df56f5b2fadb9e3 100644 (file)
@@ -8,6 +8,7 @@ from sqlalchemy import exc, types, util, dialects
 for name in dialects.__all__:
     __import__("sqlalchemy.dialects.%s" % name)
 from sqlalchemy.sql import operators, column, table
+from sqlalchemy.schema import CheckConstraint, AddConstraint
 from sqlalchemy.engine import default
 from sqlalchemy.testing.schema import Table, Column
 from sqlalchemy import testing
@@ -768,7 +769,7 @@ class UnicodeTest(fixtures.TestBase):
         )
 
 
-class EnumTest(fixtures.TestBase):
+class EnumTest(AssertsCompiledSQL, fixtures.TestBase):
     @classmethod
     def setup_class(cls):
         global enum_table, non_native_enum_table, metadata
@@ -851,6 +852,42 @@ class EnumTest(fixtures.TestBase):
             {'id': 4, 'someenum': 'four'}
         )
 
+    def test_non_native_constraint_custom_type(self):
+        class Foob(object):
+            def __init__(self, name):
+                self.name = name
+
+        class MyEnum(types.SchemaType, TypeDecorator):
+            def __init__(self, values):
+                self.impl = Enum(
+                                *[v.name for v in values],
+                                name="myenum",
+                                native_enum=False
+                            )
+
+
+            def _set_table(self, table, column):
+                self.impl._set_table(table, column)
+
+            # future method
+            def process_literal_param(self, value, dialect):
+                return value.name
+
+            def process_bind_param(self, value, dialect):
+                return value.name
+
+        m = MetaData()
+        t1 = Table('t', m, Column('x', MyEnum([Foob('a'), Foob('b')])))
+        const = [c for c in t1.constraints if isinstance(c, CheckConstraint)][0]
+
+        self.assert_compile(
+            AddConstraint(const),
+            "ALTER TABLE t ADD CONSTRAINT myenum CHECK (x IN ('a', 'b'))",
+            dialect="default"
+        )
+
+
+
     @testing.fails_on('mysql',
                     "the CHECK constraint doesn't raise an exception for unknown reason")
     def test_non_native_constraint(self):
@@ -1453,7 +1490,7 @@ class IntervalTest(fixtures.TestBase, AssertsExecutionResults):
         eq_(row['non_native_interval'], None)
 
 
-class BooleanTest(fixtures.TestBase, AssertsExecutionResults):
+class BooleanTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
     @classmethod
     def setup_class(cls):
         global bool_table
@@ -1515,6 +1552,35 @@ class BooleanTest(fixtures.TestBase, AssertsExecutionResults):
         testing.db.execute(
             "insert into booltest (id, unconstrained_value) values (1, 5)")
 
+    def test_non_native_constraint_custom_type(self):
+        class Foob(object):
+            def __init__(self, value):
+                self.value = value
+
+        class MyBool(types.SchemaType, TypeDecorator):
+            impl = Boolean()
+
+            def _set_table(self, table, column):
+                self.impl._set_table(table, column)
+
+            # future method
+            def process_literal_param(self, value, dialect):
+                return value.value
+
+            def process_bind_param(self, value, dialect):
+                return value.value
+
+        m = MetaData()
+        t1 = Table('t', m, Column('x', MyBool()))
+        const = [c for c in t1.constraints if isinstance(c, CheckConstraint)][0]
+
+        self.assert_compile(
+            AddConstraint(const),
+            "ALTER TABLE t ADD CHECK (x IN (0, 1))",
+            dialect="sqlite"
+        )
+
+
 class PickleTest(fixtures.TestBase):
     def test_eq_comparison(self):
         p1 = PickleType()