]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed bug in TypeDecorator whereby the dialect-specific
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 21 Oct 2010 20:42:32 +0000 (16:42 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 21 Oct 2010 20:42:32 +0000 (16:42 -0400)
type was getting pulled in to generate the DDL for a
given type, which didn't always return the correct result.

- TypeDecorator can now have a fully constructed type
specified as its "impl", in addition to a type class.

- TypeDecorator will now place itself as the resulting
type for a binary expression where the type coercion
rules would normally return its impl type - previously,
a copy of the impl type would be returned which would
have the TypeDecorator embedded into it as the "dialect"
impl, this was probably an unintentional way of achieving
the desired effect.

- TypeDecorator.load_dialect_impl() returns "self.impl" by
default, i.e. not the dialect implementation type of
"self.impl".   This to support compilation correctly.
Behavior can be user-overridden in exactly the same way
as before to the same effect.

CHANGES
lib/sqlalchemy/types.py
test/sql/test_types.py

diff --git a/CHANGES b/CHANGES
index 375ce905950ecdf6998ea67cea33f9d402cf5745..9eb89b07bf9649f358cd46e0634c54754b5b8bae 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -131,6 +131,27 @@ CHANGES
     [ticket:1932]
     
 - sql
+   - Fixed bug in TypeDecorator whereby the dialect-specific
+     type was getting pulled in to generate the DDL for a 
+     given type, which didn't always return the correct result.
+     
+   - TypeDecorator can now have a fully constructed type
+     specified as its "impl", in addition to a type class.
+
+   - TypeDecorator will now place itself as the resulting
+     type for a binary expression where the type coercion
+     rules would normally return its impl type - previously,
+     a copy of the impl type would be returned which would
+     have the TypeDecorator embedded into it as the "dialect"
+     impl, this was probably an unintentional way of achieving
+     the desired effect.
+
+   - TypeDecorator.load_dialect_impl() returns "self.impl" by
+     default, i.e. not the dialect implementation type of 
+     "self.impl".   This to support compilation correctly.  
+     Behavior can be user-overridden in exactly the same way 
+     as before to the same effect.
+     
    - Table.tometadata() now copies Index objects associated
      with the Table as well.
 
index 61bff6ea6d7fa4752ceae89d80ac3c6529e77ced..ee1fdc67f57b15147a728228e3962b5c41b63d74 100644 (file)
@@ -356,21 +356,19 @@ class TypeDecorator(AbstractType):
                                  "require a class-level variable "
                                  "'impl' which refers to the class of "
                                  "type being decorated")
-        self.impl = self.__class__.impl(*args, **kwargs)
+        self.impl = to_instance(self.__class__.impl, *args, **kwargs)
     
     def adapt(self, cls):
         return cls()
         
     def dialect_impl(self, dialect):
         key = (dialect.__class__, dialect.server_version_info)
+
         try:
             return self._impl_dict[key]
         except KeyError:
             pass
 
-        # adapt the TypeDecorator first, in
-        # the case that the dialect maps the TD
-        # to one of its native types (i.e. PGInterval)
         adapted = dialect.type_descriptor(self)
         if adapted is not self:
             self._impl_dict[key] = adapted
@@ -379,7 +377,7 @@ class TypeDecorator(AbstractType):
         # otherwise adapt the impl type, link
         # to a copy of this TypeDecorator and return
         # that.
-        typedesc = self.load_dialect_impl(dialect)
+        typedesc = self.load_dialect_impl(dialect).dialect_impl(dialect)
         tt = self.copy()
         if not isinstance(tt, self.__class__):
             raise AssertionError('Type object %s does not properly '
@@ -390,28 +388,34 @@ class TypeDecorator(AbstractType):
         self._impl_dict[key] = tt
         return tt
 
+    @util.memoized_property
+    def _impl_dict(self):
+        return {}
+
     @util.memoized_property
     def _type_affinity(self):
         return self.impl._type_affinity
 
     def type_engine(self, dialect):
-        impl = self.dialect_impl(dialect)
-        if not isinstance(impl, TypeDecorator):
-            return impl
+        """Return a TypeEngine instance for this TypeDecorator.
+        
+        """
+        adapted = dialect.type_descriptor(self)
+        if adapted is not self:
+            return adapted
+        elif isinstance(self.impl, TypeDecorator):
+            return self.impl.type_engine(dialect)
         else:
-            return impl.impl
+            return self.load_dialect_impl(dialect)
 
     def load_dialect_impl(self, dialect):
-        """Loads the dialect-specific implementation of this type.
+        """User hook which can be overridden to provide a different 'impl'
+        type per-dialect.
 
-        by default calls dialect.type_descriptor(self.impl), but
-        can be overridden to provide different behavior.
+        by default returns self.impl.
 
         """
-        if isinstance(self.impl, TypeDecorator):
-            return self.impl.dialect_impl(dialect)
-        else:
-            return dialect.type_descriptor(self.impl)
+        return self.impl
 
     def __getattr__(self, key):
         """Proxy all other undefined accessors to the underlying
@@ -513,9 +517,11 @@ class TypeDecorator(AbstractType):
         return self.impl.is_mutable()
 
     def _adapt_expression(self, op, othertype):
-        return self.impl._adapt_expression(op, othertype)
-
-
+        op, typ =self.impl._adapt_expression(op, othertype)
+        if typ is self.impl:
+            return op, self
+        else:
+            return op, typ
 
 class MutableType(object):
     """A mixin that marks a :class:`TypeEngine` as representing
@@ -603,12 +609,12 @@ class MutableType(object):
         """Compare *x* == *y*."""
         return x == y
 
-def to_instance(typeobj):
+def to_instance(typeobj, *arg, **kw):
     if typeobj is None:
         return NULLTYPE
 
     if util.callable(typeobj):
-        return typeobj()
+        return typeobj(*arg, **kw)
     else:
         return typeobj
 
index ad8af31a0122514ebba20e02f89e590cadc194dd..437d69fffa0fde5bf4b9e83ddb50f171fe09a139 100644 (file)
@@ -118,7 +118,7 @@ class PickleMetadataTest(TestBase):
                 mt = loads(dumps(meta))
                 
 
-class UserDefinedTest(TestBase):
+class UserDefinedTest(TestBase, AssertsCompiledSQL):
     """tests user-defined types."""
 
     def test_processing(self):
@@ -148,6 +148,60 @@ class UserDefinedTest(TestBase):
             for col in row[3], row[4]:
                 assert isinstance(col, unicode)
 
+    def test_typedecorator_impl(self):
+        for impl_, exp, kw in [
+            (Float, "FLOAT", {}),
+            (Float, "FLOAT(2)", {'precision':2}),
+            (Float(2), "FLOAT(2)", {'precision':4}),
+            (Numeric(19, 2), "NUMERIC(19, 2)", {}),
+        ]:
+            for dialect_ in (postgresql, mssql, mysql):
+                dialect_ = dialect_.dialect()
+                
+                raw_impl = types.to_instance(impl_, **kw)
+                
+                class MyType(types.TypeDecorator):
+                    impl = impl_
+                
+                dec_type = MyType(**kw)
+                
+                eq_(dec_type.impl.__class__, raw_impl.__class__)
+                
+                raw_dialect_impl = raw_impl.dialect_impl(dialect_)
+                dec_dialect_impl = dec_type.dialect_impl(dialect_)
+                eq_(dec_dialect_impl.__class__, MyType)
+                eq_(raw_dialect_impl.__class__ , dec_dialect_impl.impl.__class__)
+                
+                self.assert_compile(
+                    MyType(**kw),
+                    exp,
+                    dialect=dialect_
+                )
+    
+    def test_user_defined_typedec_impl(self):
+        class MyType(types.TypeDecorator):
+            impl = Float
+            
+            def load_dialect_impl(self, dialect):
+                if dialect.name == 'sqlite':
+                    return String(50)
+                else:
+                    return super(MyType, self).load_dialect_impl(dialect)
+        
+        sl = sqlite.dialect()
+        pg = postgresql.dialect()
+        t = MyType()
+        self.assert_compile(t, "VARCHAR(50)", dialect=sl)
+        self.assert_compile(t, "FLOAT", dialect=pg)
+        eq_(
+            t.dialect_impl(dialect=sl).impl.__class__, 
+            String().dialect_impl(dialect=sl).__class__
+        )
+        eq_(
+                t.dialect_impl(dialect=pg).impl.__class__, 
+                Float().dialect_impl(pg).__class__
+        )
+                
     @classmethod
     def setup_class(cls):
         global users, metadata
@@ -838,8 +892,9 @@ class ExpressionTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
     def test_typedec_operator_adapt(self):
         expr = test_table.c.bvalue + "hi"
         
-        assert expr.type.__class__ is String
-
+        assert expr.type.__class__ is MyTypeDec
+        assert expr.right.type.__class__ is MyTypeDec
+        
         eq_(
             testing.db.execute(select([expr.label('foo')])).scalar(),
             "BIND_INfooBIND_INhiBIND_OUT"
@@ -864,7 +919,7 @@ class ExpressionTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
             use_default_dialect=True
         )
         
-        assert expr.type.__class__ is String
+        assert expr.type.__class__ is MyTypeDec
         eq_(
             testing.db.execute(select([expr.label('foo')])).scalar(),
             "BIND_INfooBIND_IN6BIND_OUT"
@@ -944,8 +999,6 @@ class ExpressionTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
             "a + b"
         )
         
-
-        
     def test_expression_typing(self):
         expr = column('bar', Integer) - 3