]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- add a type_coerce() step within Enum, Boolean to the CHECK constraint,
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 20 Oct 2013 20:25:46 +0000 (16:25 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 20 Oct 2013 20:31:45 +0000 (16:31 -0400)
so that the custom type isn't exposed to an operation that is against the
"impl" type's constraint, [ticket:2842]
- this change showed up as some recursion overflow in pickling with labels,
add a __reduce__() there....pickling of expressions is less and less something
that's very viable...

Conflicts:
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/sqltypes.py

doc/build/changelog/changelog_08.rst
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/types.py
test/sql/test_types.py

index 1ac48c7092839c89f0ce251cd9144248e5508f18..a6352a767c01c6aed622e5760ad5064bd1ad72ef 100644 (file)
 .. changelog::
     :version: 0.8.3
 
+    .. change::
+        :tags: bug, sql
+        :tickets: 2842
+        :versions: 0.9.0
+
+        The :class:`.Enum` and :class:`.Boolean` types now bypass
+        any custom (e.g. TypeDecorator) type in use when producing the
+        CHECK constraint for the "non native" type.  This so that the custom type
+        isn't involved in the expression within the CHECK, since this
+        expression is against the "impl" value and not the "decorated" value.
+
     .. change::
         :tags: bug, postgresql
         :tickets: 2844
         Fixed bug in default compiler plus those of postgresql, mysql, and
         mssql to ensure that any literal SQL expression values are
         rendered directly as literals, instead of as bound parameters,
-        within a CREATE INDEX statement.
+        within a CREATE INDEX statement.  This also changes the rendering
+        scheme for other DDL such as constraints.
 
     .. change::
         :tags: bug, sql
index 5a97c2222b394af1350f66d0195f93fcea05e393..837716a41ccd7ad3b27be22d50b5c3fcebbce18b 100644 (file)
@@ -4347,6 +4347,9 @@ class Label(ColumnElement):
         self.quote = element.quote
         self._proxies = [element]
 
+    def __reduce__(self):
+        return self.__class__, (self.name, self._element, self._type)
+
     @util.memoized_property
     def type(self):
         return sqltypes.to_instance(
index 7ab3207bfa5cb5301deb6f5b97033e606ab8bb98..808114b8be6da1ae026ad7a8bc6cfb155cba5b86 100644 (file)
@@ -24,7 +24,7 @@ import datetime as dt
 import codecs
 
 from . import exc, schema, util, processors, events, event
-from .sql import operators
+from .sql import operators, type_coerce
 from .sql.expression import _DefaultColumnComparator
 from .util import pickle
 from .sql.visitors import Visitable
@@ -2061,7 +2061,7 @@ class Enum(String, SchemaType):
             SchemaType._set_table(self, column, table)
 
         e = schema.CheckConstraint(
-                        column.in_(self.enums),
+                        type_coerce(column, self).in_(self.enums),
                         name=self.name,
                         _create_rule=util.portable_instancemethod(
                                         self._should_create_constraint)
@@ -2198,7 +2198,7 @@ class Boolean(TypeEngine, SchemaType):
             return
 
         e = schema.CheckConstraint(
-                        column.in_([0, 1]),
+                        type_coerce(column, self).in_([0, 1]),
                         name=self.name,
                         _create_rule=util.portable_instancemethod(
                                     self._should_create_constraint)
index f80eb7134972c95b70f410413168a61e3f73e99a..325e957348df47dca264b03655c5dae50db8ba78 100644 (file)
@@ -8,6 +8,7 @@ from sqlalchemy import exc, types, util, dialects
 for name in dialects.__all__:
     __import__("sqlalchemy.dialects.%s" % name)
 from sqlalchemy.sql import operators, column, table
+from sqlalchemy.schema import CheckConstraint, AddConstraint
 from sqlalchemy.engine import default
 from sqlalchemy.testing.schema import Table, Column
 from sqlalchemy import testing
@@ -787,7 +788,7 @@ class UnicodeTest(fixtures.TestBase):
         )
 
 
-class EnumTest(fixtures.TestBase):
+class EnumTest(AssertsCompiledSQL, fixtures.TestBase):
     @classmethod
     def setup_class(cls):
         global enum_table, non_native_enum_table, metadata
@@ -870,6 +871,42 @@ class EnumTest(fixtures.TestBase):
             {'id': 4, 'someenum': 'four'}
         )
 
+    def test_non_native_constraint_custom_type(self):
+        class Foob(object):
+            def __init__(self, name):
+                self.name = name
+
+        class MyEnum(types.SchemaType, TypeDecorator):
+            def __init__(self, values):
+                self.impl = Enum(
+                                *[v.name for v in values],
+                                name="myenum",
+                                native_enum=False
+                            )
+
+
+            def _set_table(self, table, column):
+                self.impl._set_table(table, column)
+
+            # future method
+            def process_literal_param(self, value, dialect):
+                return value.name
+
+            def process_bind_param(self, value, dialect):
+                return value.name
+
+        m = MetaData()
+        t1 = Table('t', m, Column('x', MyEnum([Foob('a'), Foob('b')])))
+        const = [c for c in t1.constraints if isinstance(c, CheckConstraint)][0]
+
+        self.assert_compile(
+            AddConstraint(const),
+            "ALTER TABLE t ADD CONSTRAINT myenum CHECK (x IN ('a', 'b'))",
+            dialect=default.DefaultDialect()
+        )
+
+
+
     @testing.fails_on('mysql',
                     "the CHECK constraint doesn't raise an exception for unknown reason")
     def test_non_native_constraint(self):
@@ -1472,7 +1509,7 @@ class IntervalTest(fixtures.TestBase, AssertsExecutionResults):
         eq_(row['non_native_interval'], None)
 
 
-class BooleanTest(fixtures.TestBase, AssertsExecutionResults):
+class BooleanTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
     @classmethod
     def setup_class(cls):
         global bool_table
@@ -1534,6 +1571,35 @@ class BooleanTest(fixtures.TestBase, AssertsExecutionResults):
         testing.db.execute(
             "insert into booltest (id, unconstrained_value) values (1, 5)")
 
+    def test_non_native_constraint_custom_type(self):
+        class Foob(object):
+            def __init__(self, value):
+                self.value = value
+
+        class MyBool(types.SchemaType, TypeDecorator):
+            impl = Boolean()
+
+            def _set_table(self, table, column):
+                self.impl._set_table(table, column)
+
+            # future method
+            def process_literal_param(self, value, dialect):
+                return value.value
+
+            def process_bind_param(self, value, dialect):
+                return value.value
+
+        m = MetaData()
+        t1 = Table('t', m, Column('x', MyBool()))
+        const = [c for c in t1.constraints if isinstance(c, CheckConstraint)][0]
+
+        self.assert_compile(
+            AddConstraint(const),
+            "ALTER TABLE t ADD CHECK (x IN (0, 1))",
+            dialect=dialects.sqlite.dialect()
+        )
+
+
 class PickleTest(fixtures.TestBase):
     def test_eq_comparison(self):
         p1 = PickleType()