]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added new methods to TypeDecorator, process_bind_param() and
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 8 Dec 2007 23:03:22 +0000 (23:03 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 8 Dec 2007 23:03:22 +0000 (23:03 +0000)
process_result_value(), which automatically take advantage of the processing
of the underlying type.  Ideal for using with Unicode or Pickletype.
TypeDecorator should now be the primary way to augment the behavior of any
existing type including other TypeDecorator subclasses such as PickleType.

CHANGES
doc/build/content/types.txt
lib/sqlalchemy/types.py
test/sql/testtypes.py

diff --git a/CHANGES b/CHANGES
index 2e1160c20152791d01270515b2c90d45df92ba71..c9fa81e295cbcfb97e6897bfcd3667466e6f387f 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -30,7 +30,13 @@ CHANGES
     
     - bindparam() objects themselves can be used as keys for execute(), i.e.
       statement.execute({bind1:'foo', bind2:'bar'})
-      
+    
+    - added new methods to TypeDecorator, process_bind_param() and
+      process_result_value(), which automatically take advantage of the processing
+      of the underlying type.  Ideal for using with Unicode or Pickletype.
+      TypeDecorator should now be the primary way to augment the behavior of any
+      existing type including other TypeDecorator subclasses such as PickleType.
+        
     - tables with schemas can still be used in sqlite, firebird,
       schema name just gets dropped [ticket:890]
 
index c1343f7b6cdfe0483df05f0694cfe18291fcab5d..0f88f4973a5925254e17d8f1b2ce8f2b98f5cd3a 100644 (file)
@@ -102,21 +102,29 @@ Or some postgres types:
 
 ### Creating your Own Types {@name=custom}
 
-User-defined types can be created, to support either database-specific types, or customized pre-processing of query parameters as well as post-processing of result set data.  You can make your own classes to perform these operations.  To augment the behavior of a `TypeEngine` type, such as `String`, the `TypeDecorator` class is used:
+User-defined types can be created which can augment the bind parameter and result processing capabilities of the built in types.  This is usually achieved using the `TypeDecorator` class, which "decorates" the behavior of any existing type.  As of version 0.4.2, the new `process_bind_param()` and `process_result_value()` methods should be used:
 
     {python}
     import sqlalchemy.types as types
 
     class MyType(types.TypeDecorator):
-        """basic type that decorates String, prefixes values with "PREFIX:" on 
+        """a type that decorates Unicode, prefixes values with "PREFIX:" on 
         the way in and strips it off on the way out."""
-        impl = types.String
-        def convert_bind_param(self, value, engine):
+        
+        impl = types.Unicode
+        
+        def process_bind_param(self, value, engine):
             return "PREFIX:" + value
-        def convert_result_value(self, value, engine):
-            return value[7:]
             
-The `PickleType` class is an instance of `TypeDecorator` already and can be subclassed directly.
+        def process_result_value(self, value, engine):
+            return value[7:]
+        
+        def copy(self):
+            return MyType(self.impl.length)
+
+Note that the "old" way to process bind params and result values, the `convert_bind_param()` and `convert_result_value()` methods, are still available.  The downside of these is that when using a type which already processes data such as the `Unicode` type, you need to call the superclass version of these methods directly.  Using `process_bind_param()` and `process_result_value()`, user-defined code can return and receive the desired Python data directly.
+
+As of version 0.4.2, `TypeDecorator` should generally be used for any user-defined type which redefines the behavior of another type, including other `TypeDecorator` subclasses such as `PickleType`, and the new `process_...()` methods described above should be used.  
 
 To build a type object from scratch, which will not have a corresponding database-specific implementation, subclass `TypeEngine`:
 
@@ -126,10 +134,13 @@ To build a type object from scratch, which will not have a corresponding databas
     class MyType(types.TypeEngine):
         def __init__(self, precision = 8):
             self.precision = precision
+            
         def get_col_spec(self):
             return "MYTYPE(%s)" % self.precision
+            
         def convert_bind_param(self, value, engine):
             return value
+            
         def convert_result_value(self, value, engine):
             return value
 
index 59b5282ec502d293b196a38daa03694d189a5636..fb54db7a3892124565521473bcf8cce9cd56c175 100644 (file)
@@ -197,7 +197,10 @@ class TypeDecorator(AbstractType):
         except KeyError:
             pass
 
-        typedesc = self.load_dialect_impl(dialect)
+        if isinstance(self.impl, TypeDecorator):
+            typedesc = self.impl.dialect_impl(dialect)
+        else:
+            typedesc = self.load_dialect_impl(dialect)
         tt = self.copy()
         if not isinstance(tt, self.__class__):
             raise exceptions.AssertionError("Type object %s does not properly implement the copy() method, it must return an object of type %s" % (self, self.__class__))
@@ -211,7 +214,7 @@ class TypeDecorator(AbstractType):
         by default calls dialect.type_descriptor(self.impl), but
         can be overridden to provide different behavior.
         """
-
+        
         return dialect.type_descriptor(self.impl)
 
     def __getattr__(self, key):
@@ -222,11 +225,35 @@ class TypeDecorator(AbstractType):
     def get_col_spec(self):
         return self.impl.get_col_spec()
 
+    def process_bind_param(self, value, dialect):
+        raise NotImplementedError()
+    
+    def process_result_value(self, value, dialect):
+        raise NotImplementedError()
+        
     def bind_processor(self, dialect):
-        return self.impl.bind_processor(dialect)
+        if 'process_bind_param' in self.__class__.__dict__:
+            impl_processor = self.impl.bind_processor(dialect)
+            if impl_processor:
+                def process(value):
+                    return impl_processor(self.process_bind_param(value, dialect))
+                return process
+            else:
+                return self.process_bind_param
+        else:
+            return self.impl.bind_processor(dialect)
 
     def result_processor(self, dialect):
-        return self.impl.result_processor(dialect)
+        if 'process_result_value' in self.__class__.__dict__:
+            impl_processor = self.impl.result_processor(dialect)
+            if impl_processor:
+                def process(value):
+                    return self.process_result_value(impl_processor(value), dialect)
+                return process
+            else:
+                return self.process_result_value
+        else:
+            return self.impl.result_processor(dialect)
 
     def copy(self):
         instance = self.__class__.__new__(self.__class__)
index b8154d32e661f5ef88c3155968a950eee7ea1f01..5ea0921310507503196367bc4f81ca9ec477e2b2 100644 (file)
@@ -38,6 +38,18 @@ class MyDecoratedType(types.TypeDecorator):
     def copy(self):
         return MyDecoratedType()
 
+class MyNewUnicodeType(types.TypeDecorator):
+    impl = Unicode
+
+    def process_bind_param(self, value, dialect):
+        return "BIND_IN" + value
+
+    def process_result_value(self, value, dialect):
+        return value + "BIND_OUT"
+
+    def copy(self):
+        return MyNewUnicodeType(self.impl.length)
+
 class MyUnicodeType(types.TypeDecorator):
     impl = Unicode
 
@@ -45,18 +57,29 @@ class MyUnicodeType(types.TypeDecorator):
         impl_processor = super(MyUnicodeType, self).bind_processor(dialect) or (lambda value:value)
         
         def process(value):
-            return "UNI_BIND_IN"+ impl_processor(value)
+            return "BIND_IN"+ impl_processor(value)
         return process
 
     def result_processor(self, dialect):
         impl_processor = super(MyUnicodeType, self).result_processor(dialect) or (lambda value:value)
         def process(value):
-            return impl_processor(value) + "UNI_BIND_OUT"
+            return impl_processor(value) + "BIND_OUT"
         return process
 
     def copy(self):
         return MyUnicodeType(self.impl.length)
 
+class MyPickleType(types.TypeDecorator):
+    impl = PickleType
+    
+    def process_bind_param(self, value, dialect):
+        value.stuff = 'this is modified stuff'
+        return value
+    
+    def process_result_value(self, value, dialect):
+        value.stuff = 'this is the right stuff'
+        return value
+        
 class LegacyType(types.TypeEngine):
     def get_col_spec(self):
         return "VARCHAR(100)"
@@ -71,10 +94,10 @@ class LegacyUnicodeType(types.TypeDecorator):
     impl = Unicode
 
     def convert_bind_param(self, value, dialect):
-        return "UNI_BIND_IN" + super(LegacyUnicodeType, self).convert_bind_param(value, dialect)
+        return "BIND_IN" + super(LegacyUnicodeType, self).convert_bind_param(value, dialect)
 
     def convert_result_value(self, value, dialect):
-        return super(LegacyUnicodeType, self).convert_result_value(value, dialect) + "UNI_BIND_OUT"
+        return super(LegacyUnicodeType, self).convert_result_value(value, dialect) + "BIND_OUT"
 
     def copy(self):
         return LegacyUnicodeType(self.impl.length)
@@ -178,17 +201,20 @@ class UserDefinedTest(PersistTest):
     def testprocessing(self):
 
         global users
-        users.insert().execute(user_id = 2, goofy = 'jack', goofy2='jack', goofy3='jack', goofy4=u'jack', goofy5=u'jack', goofy6='jack')
-        users.insert().execute(user_id = 3, goofy = 'lala', goofy2='lala', goofy3='lala', goofy4=u'lala', goofy5=u'lala', goofy6='lala')
-        users.insert().execute(user_id = 4, goofy = 'fred', goofy2='fred', goofy3='fred', goofy4=u'fred', goofy5=u'fred', goofy6='fred')
+        users.insert().execute(user_id = 2, goofy = 'jack', goofy2='jack', goofy3='jack', goofy4=u'jack', goofy5=u'jack', goofy6='jack', goofy7=u'jack')
+        users.insert().execute(user_id = 3, goofy = 'lala', goofy2='lala', goofy3='lala', goofy4=u'lala', goofy5=u'lala', goofy6='lala', goofy7=u'lala')
+        users.insert().execute(user_id = 4, goofy = 'fred', goofy2='fred', goofy3='fred', goofy4=u'fred', goofy5=u'fred', goofy6='fred', goofy7=u'fred')
 
         l = users.select().execute().fetchall()
-        assert l == [
-            (2, 'BIND_INjackBIND_OUT', 'BIND_INjackBIND_OUT', 'BIND_INjackBIND_OUT', u'UNI_BIND_INjackUNI_BIND_OUT', u'UNI_BIND_INjackUNI_BIND_OUT', 'BIND_INjackBIND_OUT'),
-            (3, 'BIND_INlalaBIND_OUT', 'BIND_INlalaBIND_OUT', 'BIND_INlalaBIND_OUT', u'UNI_BIND_INlalaUNI_BIND_OUT', u'UNI_BIND_INlalaUNI_BIND_OUT', 'BIND_INlalaBIND_OUT'),
-            (4, 'BIND_INfredBIND_OUT', 'BIND_INfredBIND_OUT', 'BIND_INfredBIND_OUT', u'UNI_BIND_INfredUNI_BIND_OUT', u'UNI_BIND_INfredUNI_BIND_OUT', 'BIND_INfredBIND_OUT')
-        ]
-
+        for assertstr, row in zip(
+            ["BIND_INjackBIND_OUT", "BIND_INlalaBIND_OUT", "BIND_INfredBIND_OUT"],
+            l
+        ):
+            for col in row[1:]:
+                self.assertEquals(col, assertstr)
+            for col in (row[4], row[5], row[7]):
+                assert isinstance(col, unicode)
+                
     def setUpAll(self):
         global users, metadata
         metadata = MetaData(testbase.db)
@@ -206,6 +232,7 @@ class UserDefinedTest(PersistTest):
             Column('goofy4', MyUnicodeType, nullable = False),
             Column('goofy5', LegacyUnicodeType, nullable = False),
             Column('goofy6', LegacyType, nullable = False),
+            Column('goofy7', MyNewUnicodeType, nullable = False),
 
         )
 
@@ -353,7 +380,8 @@ class BinaryTest(AssertMixin):
         # construct PickleType with non-native pickle module, since cPickle uses relative module
         # loading and confuses this test's parent package 'sql' with the 'sqlalchemy.sql' package relative
         # to the 'types' module
-        Column('pickled', PickleType)
+        Column('pickled', PickleType),
+        Column('mypickle', MyPickleType)
         )
         binary_table.create()
 
@@ -366,25 +394,28 @@ class BinaryTest(AssertMixin):
     def testbinary(self):
         testobj1 = pickleable.Foo('im foo 1')
         testobj2 = pickleable.Foo('im foo 2')
+        testobj3 = pickleable.Foo('im foo 3')
 
         stream1 =self.load_stream('binary_data_one.dat')
         stream2 =self.load_stream('binary_data_two.dat')
-        binary_table.insert().execute(primary_id=1, misc='binary_data_one.dat',    data=stream1, data_slice=stream1[0:100], pickled=testobj1)
+        binary_table.insert().execute(primary_id=1, misc='binary_data_one.dat',    data=stream1, data_slice=stream1[0:100], pickled=testobj1, mypickle=testobj3)
         binary_table.insert().execute(primary_id=2, misc='binary_data_two.dat', data=stream2, data_slice=stream2[0:99], pickled=testobj2)
         binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data=None, data_slice=stream2[0:99], pickled=None)
 
         for stmt in (
             binary_table.select(order_by=binary_table.c.primary_id),
-            text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType}, bind=testbase.db)
+            text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType, 'mypickle':MyPickleType}, bind=testbase.db)
         ):
             l = stmt.execute().fetchall()
             print type(stream1), type(l[0]['data']), type(l[0]['data_slice'])
             print len(stream1), len(l[0]['data']), len(l[0]['data_slice'])
-            self.assert_(list(stream1) == list(l[0]['data']))
-            self.assert_(list(stream1[0:100]) == list(l[0]['data_slice']))
-            self.assert_(list(stream2) == list(l[1]['data']))
-            self.assert_(testobj1 == l[0]['pickled'])
-            self.assert_(testobj2 == l[1]['pickled'])
+            self.assertEquals(list(stream1), list(l[0]['data']))
+            self.assertEquals(list(stream1[0:100]), list(l[0]['data_slice']))
+            self.assertEquals(list(stream2), list(l[1]['data']))
+            self.assertEquals(testobj1, l[0]['pickled'])
+            self.assertEquals(testobj2, l[1]['pickled'])
+            self.assertEquals(testobj3.moredata, l[0]['mypickle'].moredata)
+            self.assertEquals(l[0]['mypickle'].stuff, 'this is the right stuff')
 
     def load_stream(self, name, len=12579):
         f = os.path.join(os.path.dirname(testbase.__file__), name)