]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
more work on types. this is the simplest implementation which is a little more manual
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 29 Apr 2006 16:12:02 +0000 (16:12 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 29 Apr 2006 16:12:02 +0000 (16:12 +0000)
lib/sqlalchemy/types.py
test/testtypes.py

index 2b80fc63ec9e93b6095594dfefc360d503953f0b..40b94bbd8e924d94730e9d71fd7ae8081d187af5 100644 (file)
@@ -12,6 +12,7 @@ __all__ = [ 'TypeEngine', 'TypeDecorator', 'NullTypeEngine',
             ]
 
 import sqlalchemy.util as util
+import sqlalchemy.exceptions as exceptions
 try:
     import cPickle as pickle
 except:
@@ -25,11 +26,7 @@ class AbstractType(object):
             self._impl_dict = {}
             return self._impl_dict
     impl_dict = property(_get_impl_dict)
-    def get_constructor_args(self):
-        return {}
-    def adapt_args(self):
-        return self
-
+            
 class TypeEngine(AbstractType):
     def __init__(self, *args, **params):
         pass
@@ -44,16 +41,22 @@ class TypeEngine(AbstractType):
         return value
     def convert_result_value(self, value, engine):
         return value
+    def adapt(self, cls):
+        return cls()
+
+AbstractType.impl = TypeEngine
 
 class TypeDecorator(AbstractType):
-    def __init__(self, *args, **params):
-        pass
+    def __init__(self, *args, **kwargs):
+        self.impl = self.__class__.impl(*args, **kwargs)
     def engine_impl(self, engine):
         try:
             return self.impl_dict[engine]
         except:
             typedesc = engine.type_descriptor(self.impl)
-            tt = self.__class__(**self.get_constructor_args())
+            tt = self.copy()
+            if not isinstance(tt, self.__class__):
+                raise exceptions.AssertionError("Type object %s does not properly implement the copy() method, it must return an object of type %s" % (self, self.__class__))
             tt.impl = typedesc
             self.impl_dict[engine] = tt
             return tt
@@ -63,6 +66,8 @@ class TypeDecorator(AbstractType):
         return self.impl.convert_bind_param(value, engine)
     def convert_result_value(self, value, engine):
         return self.impl.convert_result_value(value, engine)
+    def copy(self):
+        raise NotImplementedError()
         
 def to_instance(typeobj):
     if typeobj is None:
@@ -74,8 +79,7 @@ def to_instance(typeobj):
 def adapt_type(typeobj, colspecs):
     if isinstance(typeobj, type):
         typeobj = typeobj()
-    t2 = typeobj.adapt_args()
-    for t in t2.__class__.__mro__[0:-1]:
+    for t in typeobj.__class__.__mro__[0:-1]:
         try:
             impltype = colspecs[t]
             break
@@ -84,7 +88,7 @@ def adapt_type(typeobj, colspecs):
     else:
         # couldnt adapt...raise exception ?
         return typeobj
-    return impltype(**t2.get_constructor_args())
+    return typeobj.adapt(impltype)
     
 class NullTypeEngine(TypeEngine):
     def get_col_spec(self):
@@ -96,10 +100,15 @@ class NullTypeEngine(TypeEngine):
 
     
 class String(TypeEngine):
+    def __new__(cls, *args, **kwargs):
+        if cls is not String or len(args) > 0 or kwargs.has_key('length'):
+            return super(String, cls).__new__(cls, *args, **kwargs)
+        else:
+            return super(String, TEXT).__new__(TEXT, *args, **kwargs)
     def __init__(self, length = None):
         self.length = length
-    def get_constructor_args(self):
-        return {'length':self.length}
+    def adapt(self, impltype):
+        return impltype(length=self.length)
     def convert_bind_param(self, value, engine):
         if not engine.convert_unicode or value is None or not isinstance(value, unicode):
             return value
@@ -110,11 +119,6 @@ class String(TypeEngine):
             return value
         else:
             return value.decode(engine.encoding)
-    def adapt_args(self):
-        if self.length is None:
-            return TEXT()
-        else:
-            return self
             
 class Unicode(TypeDecorator):
     impl = String
@@ -128,7 +132,9 @@ class Unicode(TypeDecorator):
              return value.decode(engine.encoding)
          else:
              return value
-              
+    def copy(self):
+        return Unicode(self.impl.length)
+        
 class Integer(TypeEngine):
     """integer datatype"""
     pass
@@ -142,22 +148,25 @@ class Numeric(TypeEngine):
     def __init__(self, precision = 10, length = 2):
         self.precision = precision
         self.length = length
-    def get_constructor_args(self):
-        return {'precision':self.precision, 'length':self.length}
+    def adapt(self, impltype):
+        return impltype(precision=self.precision, length=self.length)
 
 class Float(Numeric):
     def __init__(self, precision = 10):
         self.precision = precision
-    def get_constructor_args(self):
-        return {'precision':self.precision}
+    def adapt(self, impltype):
+        return impltype(precision=self.precision)
 
 class DateTime(TypeEngine):
+    """implements a type for datetime.datetime() objects"""
     pass
 
 class Date(TypeEngine):
+    """implements a type for datetime.date() objects"""
     pass
 
 class Time(TypeEngine):
+    """implements a type for datetime.time() objects"""
     pass
 
 class Binary(TypeEngine):
@@ -167,25 +176,26 @@ class Binary(TypeEngine):
         return engine.dbapi().Binary(value)
     def convert_result_value(self, value, engine):
         return value
-    def get_constructor_args(self):
-        return {'length':self.length}
+    def adap(self, impltype):
+        return impltype(length=self.length)
 
 class PickleType(TypeDecorator):
-      def __init__(self, protocol=pickle.HIGHEST_PROTOCOL):
-           """allows the pickle protocol to be specified"""
-           self.protocol = protocol
-           self.impl = Binary()
-      def convert_result_value(self, value, engine):
-          if value is None:
-              return None
-          buf = self.impl.convert_result_value(value, engine)
-          return pickle.loads(str(buf))
-      def convert_bind_param(self, value, engine):
-          if value is None:
-              return None
-          return self.impl.convert_bind_param(pickle.dumps(value, self.protocol), engine)
-      def get_constructor_args(self):
-            return {}
+    impl = Binary
+    def __init__(self, protocol=pickle.HIGHEST_PROTOCOL):
+       """allows the pickle protocol to be specified"""
+       self.protocol = protocol
+       super(PickleType, self).__init__()
+    def convert_result_value(self, value, engine):
+      if value is None:
+          return None
+      buf = self.impl.convert_result_value(value, engine)
+      return pickle.loads(str(buf))
+    def convert_bind_param(self, value, engine):
+      if value is None:
+          return None
+      return self.impl.convert_bind_param(pickle.dumps(value, self.protocol), engine)
+    def copy(self):
+        return PickleType(self.protocol)
 
 class Boolean(TypeEngine):
     pass
index 3903b3a816406ed8acedde7d2e219a310a6eedd1..2fb41cc7943a9fa5212b2b6126a411c38723c602 100644 (file)
@@ -14,8 +14,6 @@ class MyType(types.TypeEngine):
         return value + "BIND_OUT"
     def adapt(self, typeobj):
         return typeobj()
-    def adapt_args(self):
-        return self
 
 class MyDecoratedType(types.TypeDecorator):
     impl = String
@@ -23,12 +21,16 @@ class MyDecoratedType(types.TypeDecorator):
         return "BIND_IN"+ value
     def convert_result_value(self, value, engine):
         return value + "BIND_OUT"
-
+    def copy(self):
+        return MyDecoratedType()
+        
 class MyUnicodeType(types.Unicode):
     def convert_bind_param(self, value, engine):
         return "UNI_BIND_IN"+ value
     def convert_result_value(self, value, engine):
         return value + "UNI_BIND_OUT"
+    def copy(self):
+        return MyUnicodeType(self.impl.length)
 
 class AdaptTest(PersistTest):
     def testadapt(self):
@@ -44,6 +46,15 @@ class AdaptTest(PersistTest):
         assert t1 != t2
         assert t2 != t3
         assert t3 != t1
+    
+    def testdecorator(self):
+        t1 = Unicode(20)
+        t2 = Unicode()
+        assert isinstance(t1.impl, String)
+        assert not isinstance(t1.impl, TEXT)
+        assert (t1.impl.length == 20)
+        assert isinstance(t2.impl, TEXT)
+        assert t2.impl.length is None
         
 class OverrideTest(PersistTest):
     """tests user-defined types, including a full type as well as a TypeDecorator"""