From 337d3b268562f421b6bb0e445c6f2ba187514a4b Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 29 Apr 2006 16:12:02 +0000 Subject: [PATCH] more work on types. this is the simplest implementation which is a little more manual --- lib/sqlalchemy/types.py | 90 +++++++++++++++++++++++------------------ test/testtypes.py | 17 ++++++-- 2 files changed, 64 insertions(+), 43 deletions(-) diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 2b80fc63ec..40b94bbd8e 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -12,6 +12,7 @@ __all__ = [ 'TypeEngine', 'TypeDecorator', 'NullTypeEngine', ] import sqlalchemy.util as util +import sqlalchemy.exceptions as exceptions try: import cPickle as pickle except: @@ -25,11 +26,7 @@ class AbstractType(object): self._impl_dict = {} return self._impl_dict impl_dict = property(_get_impl_dict) - def get_constructor_args(self): - return {} - def adapt_args(self): - return self - + class TypeEngine(AbstractType): def __init__(self, *args, **params): pass @@ -44,16 +41,22 @@ class TypeEngine(AbstractType): return value def convert_result_value(self, value, engine): return value + def adapt(self, cls): + return cls() + +AbstractType.impl = TypeEngine class TypeDecorator(AbstractType): - def __init__(self, *args, **params): - pass + def __init__(self, *args, **kwargs): + self.impl = self.__class__.impl(*args, **kwargs) def engine_impl(self, engine): try: return self.impl_dict[engine] except: typedesc = engine.type_descriptor(self.impl) - tt = self.__class__(**self.get_constructor_args()) + tt = self.copy() + if not isinstance(tt, self.__class__): + raise exceptions.AssertionError("Type object %s does not properly implement the copy() method, it must return an object of type %s" % (self, self.__class__)) tt.impl = typedesc self.impl_dict[engine] = tt return tt @@ -63,6 +66,8 @@ class TypeDecorator(AbstractType): return self.impl.convert_bind_param(value, engine) def convert_result_value(self, value, engine): return self.impl.convert_result_value(value, engine) + def copy(self): + raise NotImplementedError() def to_instance(typeobj): if typeobj is None: @@ -74,8 +79,7 @@ def to_instance(typeobj): def adapt_type(typeobj, colspecs): if isinstance(typeobj, type): typeobj = typeobj() - t2 = typeobj.adapt_args() - for t in t2.__class__.__mro__[0:-1]: + for t in typeobj.__class__.__mro__[0:-1]: try: impltype = colspecs[t] break @@ -84,7 +88,7 @@ def adapt_type(typeobj, colspecs): else: # couldnt adapt...raise exception ? return typeobj - return impltype(**t2.get_constructor_args()) + return typeobj.adapt(impltype) class NullTypeEngine(TypeEngine): def get_col_spec(self): @@ -96,10 +100,15 @@ class NullTypeEngine(TypeEngine): class String(TypeEngine): + def __new__(cls, *args, **kwargs): + if cls is not String or len(args) > 0 or kwargs.has_key('length'): + return super(String, cls).__new__(cls, *args, **kwargs) + else: + return super(String, TEXT).__new__(TEXT, *args, **kwargs) def __init__(self, length = None): self.length = length - def get_constructor_args(self): - return {'length':self.length} + def adapt(self, impltype): + return impltype(length=self.length) def convert_bind_param(self, value, engine): if not engine.convert_unicode or value is None or not isinstance(value, unicode): return value @@ -110,11 +119,6 @@ class String(TypeEngine): return value else: return value.decode(engine.encoding) - def adapt_args(self): - if self.length is None: - return TEXT() - else: - return self class Unicode(TypeDecorator): impl = String @@ -128,7 +132,9 @@ class Unicode(TypeDecorator): return value.decode(engine.encoding) else: return value - + def copy(self): + return Unicode(self.impl.length) + class Integer(TypeEngine): """integer datatype""" pass @@ -142,22 +148,25 @@ class Numeric(TypeEngine): def __init__(self, precision = 10, length = 2): self.precision = precision self.length = length - def get_constructor_args(self): - return {'precision':self.precision, 'length':self.length} + def adapt(self, impltype): + return impltype(precision=self.precision, length=self.length) class Float(Numeric): def __init__(self, precision = 10): self.precision = precision - def get_constructor_args(self): - return {'precision':self.precision} + def adapt(self, impltype): + return impltype(precision=self.precision) class DateTime(TypeEngine): + """implements a type for datetime.datetime() objects""" pass class Date(TypeEngine): + """implements a type for datetime.date() objects""" pass class Time(TypeEngine): + """implements a type for datetime.time() objects""" pass class Binary(TypeEngine): @@ -167,25 +176,26 @@ class Binary(TypeEngine): return engine.dbapi().Binary(value) def convert_result_value(self, value, engine): return value - def get_constructor_args(self): - return {'length':self.length} + def adap(self, impltype): + return impltype(length=self.length) class PickleType(TypeDecorator): - def __init__(self, protocol=pickle.HIGHEST_PROTOCOL): - """allows the pickle protocol to be specified""" - self.protocol = protocol - self.impl = Binary() - def convert_result_value(self, value, engine): - if value is None: - return None - buf = self.impl.convert_result_value(value, engine) - return pickle.loads(str(buf)) - def convert_bind_param(self, value, engine): - if value is None: - return None - return self.impl.convert_bind_param(pickle.dumps(value, self.protocol), engine) - def get_constructor_args(self): - return {} + impl = Binary + def __init__(self, protocol=pickle.HIGHEST_PROTOCOL): + """allows the pickle protocol to be specified""" + self.protocol = protocol + super(PickleType, self).__init__() + def convert_result_value(self, value, engine): + if value is None: + return None + buf = self.impl.convert_result_value(value, engine) + return pickle.loads(str(buf)) + def convert_bind_param(self, value, engine): + if value is None: + return None + return self.impl.convert_bind_param(pickle.dumps(value, self.protocol), engine) + def copy(self): + return PickleType(self.protocol) class Boolean(TypeEngine): pass diff --git a/test/testtypes.py b/test/testtypes.py index 3903b3a816..2fb41cc794 100644 --- a/test/testtypes.py +++ b/test/testtypes.py @@ -14,8 +14,6 @@ class MyType(types.TypeEngine): return value + "BIND_OUT" def adapt(self, typeobj): return typeobj() - def adapt_args(self): - return self class MyDecoratedType(types.TypeDecorator): impl = String @@ -23,12 +21,16 @@ class MyDecoratedType(types.TypeDecorator): return "BIND_IN"+ value def convert_result_value(self, value, engine): return value + "BIND_OUT" - + def copy(self): + return MyDecoratedType() + class MyUnicodeType(types.Unicode): def convert_bind_param(self, value, engine): return "UNI_BIND_IN"+ value def convert_result_value(self, value, engine): return value + "UNI_BIND_OUT" + def copy(self): + return MyUnicodeType(self.impl.length) class AdaptTest(PersistTest): def testadapt(self): @@ -44,6 +46,15 @@ class AdaptTest(PersistTest): assert t1 != t2 assert t2 != t3 assert t3 != t1 + + def testdecorator(self): + t1 = Unicode(20) + t2 = Unicode() + assert isinstance(t1.impl, String) + assert not isinstance(t1.impl, TEXT) + assert (t1.impl.length == 20) + assert isinstance(t2.impl, TEXT) + assert t2.impl.length is None class OverrideTest(PersistTest): """tests user-defined types, including a full type as well as a TypeDecorator""" -- 2.47.2