]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
*another* big types change....the old way was still wrong...this way is better (still...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 29 Apr 2006 01:05:13 +0000 (01:05 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 29 Apr 2006 01:05:13 +0000 (01:05 +0000)
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/types.py
test/testtypes.py

index a25a21e9bf06ee2de09dd331592c8fc8a50f880e..60435f22039f7cc4554046296466729bff683929 100644 (file)
@@ -79,6 +79,11 @@ class MSBinary(sqltypes.Binary):
             return "BINARY(%d)" % self.length
         else:
             return "BLOB"
+    def convert_result_value(self, value, engine):
+        if value is None:
+            return None
+        else:
+            return buffer(value)
 
 class MSBoolean(sqltypes.Boolean):
     def get_col_spec(self):
@@ -142,7 +147,6 @@ class MySQLEngine(ansisql.ANSISQLEngine):
 
     def type_descriptor(self, typeobj):
         return sqltypes.adapt_type(typeobj, colspecs)
-
     def last_inserted_ids(self):
         return self.context.last_inserted_ids
 
index 7a3822a6518e4458420e379b3243dfdfcea1c852..2b80fc63ec9e93b6095594dfefc360d503953f0b 100644 (file)
@@ -17,9 +17,7 @@ try:
 except:
     import pickle
 
-class TypeEngine(object):
-    def __init__(self, *args, **kwargs):
-        pass
+class AbstractType(object):
     def _get_impl_dict(self):
         try:
             return self._impl_dict
@@ -27,32 +25,45 @@ class TypeEngine(object):
             self._impl_dict = {}
             return self._impl_dict
     impl_dict = property(_get_impl_dict)
+    def get_constructor_args(self):
+        return {}
+    def adapt_args(self):
+        return self
+
+class TypeEngine(AbstractType):
+    def __init__(self, *args, **params):
+        pass
     def engine_impl(self, engine):
         try:
             return self.impl_dict[engine]
         except:
             return self.impl_dict.setdefault(engine, engine.type_descriptor(self))
-    def _get_impl(self):
-        if hasattr(self, '_impl'):
-            return self._impl
-        else:
-            return NULLTYPE
-    def _set_impl(self, impl):
-        self._impl = impl
-    impl = property(_get_impl, _set_impl)
+    def get_col_spec(self):
+        raise NotImplementedError()
+    def convert_bind_param(self, value, engine):
+        return value
+    def convert_result_value(self, value, engine):
+        return value
+
+class TypeDecorator(AbstractType):
+    def __init__(self, *args, **params):
+        pass
+    def engine_impl(self, engine):
+        try:
+            return self.impl_dict[engine]
+        except:
+            typedesc = engine.type_descriptor(self.impl)
+            tt = self.__class__(**self.get_constructor_args())
+            tt.impl = typedesc
+            self.impl_dict[engine] = tt
+            return tt
     def get_col_spec(self):
         return self.impl.get_col_spec()
     def convert_bind_param(self, value, engine):
         return self.impl.convert_bind_param(value, engine)
     def convert_result_value(self, value, engine):
         return self.impl.convert_result_value(value, engine)
-    def set_impl(self, impltype):
-        self.impl = impltype(**self.get_constructor_args())
-    def get_constructor_args(self):
-        return {}
-    def adapt_args(self):
-        return self
-
+        
 def to_instance(typeobj):
     if typeobj is None:
         return NULLTYPE
@@ -73,9 +84,7 @@ def adapt_type(typeobj, colspecs):
     else:
         # couldnt adapt...raise exception ?
         return typeobj
-    typeobj.set_impl(impltype)
-    typeobj.impl.impl = NULLTYPE
-    return typeobj
+    return impltype(**t2.get_constructor_args())
     
 class NullTypeEngine(TypeEngine):
     def get_col_spec(self):
@@ -85,10 +94,6 @@ class NullTypeEngine(TypeEngine):
     def convert_result_value(self, value, engine):
         return value
 
-class TypeDecorator(object):
-    """TypeDecorator is deprecated"""
-    pass
-    
     
 class String(TypeEngine):
     def __init__(self, length = None):
@@ -111,7 +116,8 @@ class String(TypeEngine):
         else:
             return self
             
-class Unicode(String):
+class Unicode(TypeDecorator):
+    impl = String
     def convert_bind_param(self, value, engine):
          if value is not None and isinstance(value, unicode):
               return value.encode(engine.encoding)
@@ -164,19 +170,20 @@ class Binary(TypeEngine):
     def get_constructor_args(self):
         return {'length':self.length}
 
-class PickleType(Binary):
+class PickleType(TypeDecorator):
       def __init__(self, protocol=pickle.HIGHEST_PROTOCOL):
            """allows the pickle protocol to be specified"""
            self.protocol = protocol
+           self.impl = Binary()
       def convert_result_value(self, value, engine):
           if value is None:
               return None
-          buf = Binary.convert_result_value(self, value, engine)
+          buf = self.impl.convert_result_value(value, engine)
           return pickle.loads(str(buf))
       def convert_bind_param(self, value, engine):
           if value is None:
               return None
-          return Binary.convert_bind_param(self, pickle.dumps(value, self.protocol), engine)
+          return self.impl.convert_bind_param(pickle.dumps(value, self.protocol), engine)
       def get_constructor_args(self):
             return {}
 
index c2e3043c999b00b3162b716ba752abe101265545..3903b3a816406ed8acedde7d2e219a310a6eedd1 100644 (file)
@@ -17,7 +17,8 @@ class MyType(types.TypeEngine):
     def adapt_args(self):
         return self
 
-class MyDecoratedType(types.String):
+class MyDecoratedType(types.TypeDecorator):
+    impl = String
     def convert_bind_param(self, value, engine):
         return "BIND_IN"+ value
     def convert_result_value(self, value, engine):
@@ -29,6 +30,21 @@ class MyUnicodeType(types.Unicode):
     def convert_result_value(self, value, engine):
         return value + "UNI_BIND_OUT"
 
+class AdaptTest(PersistTest):
+    def testadapt(self):
+        e1 = create_engine('postgres://')
+        e2 = create_engine('sqlite://')
+        e3 = create_engine('mysql://')
+        
+        type = String(40)
+        
+        t1 = type.engine_impl(e1)
+        t2 = type.engine_impl(e2)
+        t3 = type.engine_impl(e3)
+        assert t1 != t2
+        assert t2 != t3
+        assert t3 != t1
+        
 class OverrideTest(PersistTest):
     """tests user-defined types, including a full type as well as a TypeDecorator"""
 
@@ -132,6 +148,15 @@ class UnicodeTest(AssertMixin):
             self.assert_(isinstance(x['plain_data'], unicode) and x['plain_data'] == unicodedata)
         finally:
             db.engine.convert_unicode = prev_unicode
+
+
+class Foo(object):
+    def __init__(self, moredata):
+        self.data = 'im data'
+        self.stuff = 'im stuff'
+        self.moredata = moredata
+    def __eq__(self, other):
+        return other.data == self.data and other.stuff == self.stuff and other.moredata==self.moredata
     
 class BinaryTest(AssertMixin):
     def setUpAll(self):
@@ -140,20 +165,29 @@ class BinaryTest(AssertMixin):
         Column('primary_id', Integer, primary_key=True),
         Column('data', Binary),
         Column('data_slice', Binary(100)),
-        Column('misc', String(30)))
+        Column('misc', String(30)),
+        Column('pickled', PickleType))
         binary_table.create()
     def tearDownAll(self):
         binary_table.drop()
     def testbinary(self):
+        testobj1 = Foo('im foo 1')
+        testobj2 = Foo('im foo 2')
+        
         stream1 =self.get_module_stream('sqlalchemy.sql')
         stream2 =self.get_module_stream('sqlalchemy.engine')
-        binary_table.insert().execute(primary_id=1, misc='sql.pyc',    data=stream1, data_slice=stream1[0:100])
-        binary_table.insert().execute(primary_id=2, misc='engine.pyc', data=stream2, data_slice=stream2[0:99])
+        binary_table.insert().execute(primary_id=1, misc='sql.pyc',    data=stream1, data_slice=stream1[0:100], pickled=testobj1)
+        binary_table.insert().execute(primary_id=2, misc='engine.pyc', data=stream2, data_slice=stream2[0:99], pickled=testobj2)
         l = binary_table.select().execute().fetchall()
+        print type(l[0]['data'])
+        return
         print len(stream1), len(l[0]['data']), len(l[0]['data_slice'])
         self.assert_(list(stream1) == list(l[0]['data']))
         self.assert_(list(stream1[0:100]) == list(l[0]['data_slice']))
         self.assert_(list(stream2) == list(l[1]['data']))
+        self.assert_(testobj1 == l[0]['pickled'])
+        self.assert_(testobj2 == l[1]['pickled'])
+
     def get_module_stream(self, name):
         mod = __import__(name)
         for token in name.split('.')[1:]: