From 4cff2c1a3c6600376f2ce09692df233928a9f4f7 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 29 Apr 2006 01:05:13 +0000 Subject: [PATCH] *another* big types change....the old way was still wrong...this way is better (still need to go through it again since i am apparently type-impaired....) --- lib/sqlalchemy/databases/mysql.py | 6 ++- lib/sqlalchemy/types.py | 65 +++++++++++++++++-------------- test/testtypes.py | 42 ++++++++++++++++++-- 3 files changed, 79 insertions(+), 34 deletions(-) diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index a25a21e9bf..60435f2203 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -79,6 +79,11 @@ class MSBinary(sqltypes.Binary): return "BINARY(%d)" % self.length else: return "BLOB" + def convert_result_value(self, value, engine): + if value is None: + return None + else: + return buffer(value) class MSBoolean(sqltypes.Boolean): def get_col_spec(self): @@ -142,7 +147,6 @@ class MySQLEngine(ansisql.ANSISQLEngine): def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) - def last_inserted_ids(self): return self.context.last_inserted_ids diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 7a3822a651..2b80fc63ec 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -17,9 +17,7 @@ try: except: import pickle -class TypeEngine(object): - def __init__(self, *args, **kwargs): - pass +class AbstractType(object): def _get_impl_dict(self): try: return self._impl_dict @@ -27,32 +25,45 @@ class TypeEngine(object): self._impl_dict = {} return self._impl_dict impl_dict = property(_get_impl_dict) + def get_constructor_args(self): + return {} + def adapt_args(self): + return self + +class TypeEngine(AbstractType): + def __init__(self, *args, **params): + pass def engine_impl(self, engine): try: return self.impl_dict[engine] except: return self.impl_dict.setdefault(engine, engine.type_descriptor(self)) - def _get_impl(self): - if hasattr(self, '_impl'): - return self._impl - else: - return NULLTYPE - def _set_impl(self, impl): - self._impl = impl - impl = property(_get_impl, _set_impl) + def get_col_spec(self): + raise NotImplementedError() + def convert_bind_param(self, value, engine): + return value + def convert_result_value(self, value, engine): + return value + +class TypeDecorator(AbstractType): + def __init__(self, *args, **params): + pass + def engine_impl(self, engine): + try: + return self.impl_dict[engine] + except: + typedesc = engine.type_descriptor(self.impl) + tt = self.__class__(**self.get_constructor_args()) + tt.impl = typedesc + self.impl_dict[engine] = tt + return tt def get_col_spec(self): return self.impl.get_col_spec() def convert_bind_param(self, value, engine): return self.impl.convert_bind_param(value, engine) def convert_result_value(self, value, engine): return self.impl.convert_result_value(value, engine) - def set_impl(self, impltype): - self.impl = impltype(**self.get_constructor_args()) - def get_constructor_args(self): - return {} - def adapt_args(self): - return self - + def to_instance(typeobj): if typeobj is None: return NULLTYPE @@ -73,9 +84,7 @@ def adapt_type(typeobj, colspecs): else: # couldnt adapt...raise exception ? return typeobj - typeobj.set_impl(impltype) - typeobj.impl.impl = NULLTYPE - return typeobj + return impltype(**t2.get_constructor_args()) class NullTypeEngine(TypeEngine): def get_col_spec(self): @@ -85,10 +94,6 @@ class NullTypeEngine(TypeEngine): def convert_result_value(self, value, engine): return value -class TypeDecorator(object): - """TypeDecorator is deprecated""" - pass - class String(TypeEngine): def __init__(self, length = None): @@ -111,7 +116,8 @@ class String(TypeEngine): else: return self -class Unicode(String): +class Unicode(TypeDecorator): + impl = String def convert_bind_param(self, value, engine): if value is not None and isinstance(value, unicode): return value.encode(engine.encoding) @@ -164,19 +170,20 @@ class Binary(TypeEngine): def get_constructor_args(self): return {'length':self.length} -class PickleType(Binary): +class PickleType(TypeDecorator): def __init__(self, protocol=pickle.HIGHEST_PROTOCOL): """allows the pickle protocol to be specified""" self.protocol = protocol + self.impl = Binary() def convert_result_value(self, value, engine): if value is None: return None - buf = Binary.convert_result_value(self, value, engine) + buf = self.impl.convert_result_value(value, engine) return pickle.loads(str(buf)) def convert_bind_param(self, value, engine): if value is None: return None - return Binary.convert_bind_param(self, pickle.dumps(value, self.protocol), engine) + return self.impl.convert_bind_param(pickle.dumps(value, self.protocol), engine) def get_constructor_args(self): return {} diff --git a/test/testtypes.py b/test/testtypes.py index c2e3043c99..3903b3a816 100644 --- a/test/testtypes.py +++ b/test/testtypes.py @@ -17,7 +17,8 @@ class MyType(types.TypeEngine): def adapt_args(self): return self -class MyDecoratedType(types.String): +class MyDecoratedType(types.TypeDecorator): + impl = String def convert_bind_param(self, value, engine): return "BIND_IN"+ value def convert_result_value(self, value, engine): @@ -29,6 +30,21 @@ class MyUnicodeType(types.Unicode): def convert_result_value(self, value, engine): return value + "UNI_BIND_OUT" +class AdaptTest(PersistTest): + def testadapt(self): + e1 = create_engine('postgres://') + e2 = create_engine('sqlite://') + e3 = create_engine('mysql://') + + type = String(40) + + t1 = type.engine_impl(e1) + t2 = type.engine_impl(e2) + t3 = type.engine_impl(e3) + assert t1 != t2 + assert t2 != t3 + assert t3 != t1 + class OverrideTest(PersistTest): """tests user-defined types, including a full type as well as a TypeDecorator""" @@ -132,6 +148,15 @@ class UnicodeTest(AssertMixin): self.assert_(isinstance(x['plain_data'], unicode) and x['plain_data'] == unicodedata) finally: db.engine.convert_unicode = prev_unicode + + +class Foo(object): + def __init__(self, moredata): + self.data = 'im data' + self.stuff = 'im stuff' + self.moredata = moredata + def __eq__(self, other): + return other.data == self.data and other.stuff == self.stuff and other.moredata==self.moredata class BinaryTest(AssertMixin): def setUpAll(self): @@ -140,20 +165,29 @@ class BinaryTest(AssertMixin): Column('primary_id', Integer, primary_key=True), Column('data', Binary), Column('data_slice', Binary(100)), - Column('misc', String(30))) + Column('misc', String(30)), + Column('pickled', PickleType)) binary_table.create() def tearDownAll(self): binary_table.drop() def testbinary(self): + testobj1 = Foo('im foo 1') + testobj2 = Foo('im foo 2') + stream1 =self.get_module_stream('sqlalchemy.sql') stream2 =self.get_module_stream('sqlalchemy.engine') - binary_table.insert().execute(primary_id=1, misc='sql.pyc', data=stream1, data_slice=stream1[0:100]) - binary_table.insert().execute(primary_id=2, misc='engine.pyc', data=stream2, data_slice=stream2[0:99]) + binary_table.insert().execute(primary_id=1, misc='sql.pyc', data=stream1, data_slice=stream1[0:100], pickled=testobj1) + binary_table.insert().execute(primary_id=2, misc='engine.pyc', data=stream2, data_slice=stream2[0:99], pickled=testobj2) l = binary_table.select().execute().fetchall() + print type(l[0]['data']) + return print len(stream1), len(l[0]['data']), len(l[0]['data_slice']) self.assert_(list(stream1) == list(l[0]['data'])) self.assert_(list(stream1[0:100]) == list(l[0]['data_slice'])) self.assert_(list(stream2) == list(l[1]['data'])) + self.assert_(testobj1 == l[0]['pickled']) + self.assert_(testobj2 == l[1]['pickled']) + def get_module_stream(self, name): mod = __import__(name) for token in name.split('.')[1:]: -- 2.47.2