From c9b3f0bcef20794ac7296a855aafe8b75ae7630e Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 8 Dec 2007 23:03:22 +0000 Subject: [PATCH] - added new methods to TypeDecorator, process_bind_param() and process_result_value(), which automatically take advantage of the processing of the underlying type. Ideal for using with Unicode or Pickletype. TypeDecorator should now be the primary way to augment the behavior of any existing type including other TypeDecorator subclasses such as PickleType. --- CHANGES | 8 +++- doc/build/content/types.txt | 25 +++++++++---- lib/sqlalchemy/types.py | 35 ++++++++++++++++-- test/sql/testtypes.py | 73 ++++++++++++++++++++++++++----------- 4 files changed, 108 insertions(+), 33 deletions(-) diff --git a/CHANGES b/CHANGES index 2e1160c201..c9fa81e295 100644 --- a/CHANGES +++ b/CHANGES @@ -30,7 +30,13 @@ CHANGES - bindparam() objects themselves can be used as keys for execute(), i.e. statement.execute({bind1:'foo', bind2:'bar'}) - + + - added new methods to TypeDecorator, process_bind_param() and + process_result_value(), which automatically take advantage of the processing + of the underlying type. Ideal for using with Unicode or Pickletype. + TypeDecorator should now be the primary way to augment the behavior of any + existing type including other TypeDecorator subclasses such as PickleType. + - tables with schemas can still be used in sqlite, firebird, schema name just gets dropped [ticket:890] diff --git a/doc/build/content/types.txt b/doc/build/content/types.txt index c1343f7b6c..0f88f4973a 100644 --- a/doc/build/content/types.txt +++ b/doc/build/content/types.txt @@ -102,21 +102,29 @@ Or some postgres types: ### Creating your Own Types {@name=custom} -User-defined types can be created, to support either database-specific types, or customized pre-processing of query parameters as well as post-processing of result set data. You can make your own classes to perform these operations. To augment the behavior of a `TypeEngine` type, such as `String`, the `TypeDecorator` class is used: +User-defined types can be created which can augment the bind parameter and result processing capabilities of the built in types. This is usually achieved using the `TypeDecorator` class, which "decorates" the behavior of any existing type. As of version 0.4.2, the new `process_bind_param()` and `process_result_value()` methods should be used: {python} import sqlalchemy.types as types class MyType(types.TypeDecorator): - """basic type that decorates String, prefixes values with "PREFIX:" on + """a type that decorates Unicode, prefixes values with "PREFIX:" on the way in and strips it off on the way out.""" - impl = types.String - def convert_bind_param(self, value, engine): + + impl = types.Unicode + + def process_bind_param(self, value, engine): return "PREFIX:" + value - def convert_result_value(self, value, engine): - return value[7:] -The `PickleType` class is an instance of `TypeDecorator` already and can be subclassed directly. + def process_result_value(self, value, engine): + return value[7:] + + def copy(self): + return MyType(self.impl.length) + +Note that the "old" way to process bind params and result values, the `convert_bind_param()` and `convert_result_value()` methods, are still available. The downside of these is that when using a type which already processes data such as the `Unicode` type, you need to call the superclass version of these methods directly. Using `process_bind_param()` and `process_result_value()`, user-defined code can return and receive the desired Python data directly. + +As of version 0.4.2, `TypeDecorator` should generally be used for any user-defined type which redefines the behavior of another type, including other `TypeDecorator` subclasses such as `PickleType`, and the new `process_...()` methods described above should be used. To build a type object from scratch, which will not have a corresponding database-specific implementation, subclass `TypeEngine`: @@ -126,10 +134,13 @@ To build a type object from scratch, which will not have a corresponding databas class MyType(types.TypeEngine): def __init__(self, precision = 8): self.precision = precision + def get_col_spec(self): return "MYTYPE(%s)" % self.precision + def convert_bind_param(self, value, engine): return value + def convert_result_value(self, value, engine): return value diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 59b5282ec5..fb54db7a38 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -197,7 +197,10 @@ class TypeDecorator(AbstractType): except KeyError: pass - typedesc = self.load_dialect_impl(dialect) + if isinstance(self.impl, TypeDecorator): + typedesc = self.impl.dialect_impl(dialect) + else: + typedesc = self.load_dialect_impl(dialect) tt = self.copy() if not isinstance(tt, self.__class__): raise exceptions.AssertionError("Type object %s does not properly implement the copy() method, it must return an object of type %s" % (self, self.__class__)) @@ -211,7 +214,7 @@ class TypeDecorator(AbstractType): by default calls dialect.type_descriptor(self.impl), but can be overridden to provide different behavior. """ - + return dialect.type_descriptor(self.impl) def __getattr__(self, key): @@ -222,11 +225,35 @@ class TypeDecorator(AbstractType): def get_col_spec(self): return self.impl.get_col_spec() + def process_bind_param(self, value, dialect): + raise NotImplementedError() + + def process_result_value(self, value, dialect): + raise NotImplementedError() + def bind_processor(self, dialect): - return self.impl.bind_processor(dialect) + if 'process_bind_param' in self.__class__.__dict__: + impl_processor = self.impl.bind_processor(dialect) + if impl_processor: + def process(value): + return impl_processor(self.process_bind_param(value, dialect)) + return process + else: + return self.process_bind_param + else: + return self.impl.bind_processor(dialect) def result_processor(self, dialect): - return self.impl.result_processor(dialect) + if 'process_result_value' in self.__class__.__dict__: + impl_processor = self.impl.result_processor(dialect) + if impl_processor: + def process(value): + return self.process_result_value(impl_processor(value), dialect) + return process + else: + return self.process_result_value + else: + return self.impl.result_processor(dialect) def copy(self): instance = self.__class__.__new__(self.__class__) diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index b8154d32e6..5ea0921310 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -38,6 +38,18 @@ class MyDecoratedType(types.TypeDecorator): def copy(self): return MyDecoratedType() +class MyNewUnicodeType(types.TypeDecorator): + impl = Unicode + + def process_bind_param(self, value, dialect): + return "BIND_IN" + value + + def process_result_value(self, value, dialect): + return value + "BIND_OUT" + + def copy(self): + return MyNewUnicodeType(self.impl.length) + class MyUnicodeType(types.TypeDecorator): impl = Unicode @@ -45,18 +57,29 @@ class MyUnicodeType(types.TypeDecorator): impl_processor = super(MyUnicodeType, self).bind_processor(dialect) or (lambda value:value) def process(value): - return "UNI_BIND_IN"+ impl_processor(value) + return "BIND_IN"+ impl_processor(value) return process def result_processor(self, dialect): impl_processor = super(MyUnicodeType, self).result_processor(dialect) or (lambda value:value) def process(value): - return impl_processor(value) + "UNI_BIND_OUT" + return impl_processor(value) + "BIND_OUT" return process def copy(self): return MyUnicodeType(self.impl.length) +class MyPickleType(types.TypeDecorator): + impl = PickleType + + def process_bind_param(self, value, dialect): + value.stuff = 'this is modified stuff' + return value + + def process_result_value(self, value, dialect): + value.stuff = 'this is the right stuff' + return value + class LegacyType(types.TypeEngine): def get_col_spec(self): return "VARCHAR(100)" @@ -71,10 +94,10 @@ class LegacyUnicodeType(types.TypeDecorator): impl = Unicode def convert_bind_param(self, value, dialect): - return "UNI_BIND_IN" + super(LegacyUnicodeType, self).convert_bind_param(value, dialect) + return "BIND_IN" + super(LegacyUnicodeType, self).convert_bind_param(value, dialect) def convert_result_value(self, value, dialect): - return super(LegacyUnicodeType, self).convert_result_value(value, dialect) + "UNI_BIND_OUT" + return super(LegacyUnicodeType, self).convert_result_value(value, dialect) + "BIND_OUT" def copy(self): return LegacyUnicodeType(self.impl.length) @@ -178,17 +201,20 @@ class UserDefinedTest(PersistTest): def testprocessing(self): global users - users.insert().execute(user_id = 2, goofy = 'jack', goofy2='jack', goofy3='jack', goofy4=u'jack', goofy5=u'jack', goofy6='jack') - users.insert().execute(user_id = 3, goofy = 'lala', goofy2='lala', goofy3='lala', goofy4=u'lala', goofy5=u'lala', goofy6='lala') - users.insert().execute(user_id = 4, goofy = 'fred', goofy2='fred', goofy3='fred', goofy4=u'fred', goofy5=u'fred', goofy6='fred') + users.insert().execute(user_id = 2, goofy = 'jack', goofy2='jack', goofy3='jack', goofy4=u'jack', goofy5=u'jack', goofy6='jack', goofy7=u'jack') + users.insert().execute(user_id = 3, goofy = 'lala', goofy2='lala', goofy3='lala', goofy4=u'lala', goofy5=u'lala', goofy6='lala', goofy7=u'lala') + users.insert().execute(user_id = 4, goofy = 'fred', goofy2='fred', goofy3='fred', goofy4=u'fred', goofy5=u'fred', goofy6='fred', goofy7=u'fred') l = users.select().execute().fetchall() - assert l == [ - (2, 'BIND_INjackBIND_OUT', 'BIND_INjackBIND_OUT', 'BIND_INjackBIND_OUT', u'UNI_BIND_INjackUNI_BIND_OUT', u'UNI_BIND_INjackUNI_BIND_OUT', 'BIND_INjackBIND_OUT'), - (3, 'BIND_INlalaBIND_OUT', 'BIND_INlalaBIND_OUT', 'BIND_INlalaBIND_OUT', u'UNI_BIND_INlalaUNI_BIND_OUT', u'UNI_BIND_INlalaUNI_BIND_OUT', 'BIND_INlalaBIND_OUT'), - (4, 'BIND_INfredBIND_OUT', 'BIND_INfredBIND_OUT', 'BIND_INfredBIND_OUT', u'UNI_BIND_INfredUNI_BIND_OUT', u'UNI_BIND_INfredUNI_BIND_OUT', 'BIND_INfredBIND_OUT') - ] - + for assertstr, row in zip( + ["BIND_INjackBIND_OUT", "BIND_INlalaBIND_OUT", "BIND_INfredBIND_OUT"], + l + ): + for col in row[1:]: + self.assertEquals(col, assertstr) + for col in (row[4], row[5], row[7]): + assert isinstance(col, unicode) + def setUpAll(self): global users, metadata metadata = MetaData(testbase.db) @@ -206,6 +232,7 @@ class UserDefinedTest(PersistTest): Column('goofy4', MyUnicodeType, nullable = False), Column('goofy5', LegacyUnicodeType, nullable = False), Column('goofy6', LegacyType, nullable = False), + Column('goofy7', MyNewUnicodeType, nullable = False), ) @@ -353,7 +380,8 @@ class BinaryTest(AssertMixin): # construct PickleType with non-native pickle module, since cPickle uses relative module # loading and confuses this test's parent package 'sql' with the 'sqlalchemy.sql' package relative # to the 'types' module - Column('pickled', PickleType) + Column('pickled', PickleType), + Column('mypickle', MyPickleType) ) binary_table.create() @@ -366,25 +394,28 @@ class BinaryTest(AssertMixin): def testbinary(self): testobj1 = pickleable.Foo('im foo 1') testobj2 = pickleable.Foo('im foo 2') + testobj3 = pickleable.Foo('im foo 3') stream1 =self.load_stream('binary_data_one.dat') stream2 =self.load_stream('binary_data_two.dat') - binary_table.insert().execute(primary_id=1, misc='binary_data_one.dat', data=stream1, data_slice=stream1[0:100], pickled=testobj1) + binary_table.insert().execute(primary_id=1, misc='binary_data_one.dat', data=stream1, data_slice=stream1[0:100], pickled=testobj1, mypickle=testobj3) binary_table.insert().execute(primary_id=2, misc='binary_data_two.dat', data=stream2, data_slice=stream2[0:99], pickled=testobj2) binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data=None, data_slice=stream2[0:99], pickled=None) for stmt in ( binary_table.select(order_by=binary_table.c.primary_id), - text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType}, bind=testbase.db) + text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType, 'mypickle':MyPickleType}, bind=testbase.db) ): l = stmt.execute().fetchall() print type(stream1), type(l[0]['data']), type(l[0]['data_slice']) print len(stream1), len(l[0]['data']), len(l[0]['data_slice']) - self.assert_(list(stream1) == list(l[0]['data'])) - self.assert_(list(stream1[0:100]) == list(l[0]['data_slice'])) - self.assert_(list(stream2) == list(l[1]['data'])) - self.assert_(testobj1 == l[0]['pickled']) - self.assert_(testobj2 == l[1]['pickled']) + self.assertEquals(list(stream1), list(l[0]['data'])) + self.assertEquals(list(stream1[0:100]), list(l[0]['data_slice'])) + self.assertEquals(list(stream2), list(l[1]['data'])) + self.assertEquals(testobj1, l[0]['pickled']) + self.assertEquals(testobj2, l[1]['pickled']) + self.assertEquals(testobj3.moredata, l[0]['mypickle'].moredata) + self.assertEquals(l[0]['mypickle'].stuff, 'this is the right stuff') def load_stream(self, name, len=12579): f = os.path.join(os.path.dirname(testbase.__file__), name) -- 2.47.3