]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
beefed up type adaptation methodology, got Unicode to do encode/decode + test case
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 11 Feb 2006 18:54:51 +0000 (18:54 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 11 Feb 2006 18:54:51 +0000 (18:54 +0000)
lib/sqlalchemy/types.py
test/types.py

index ce1cec778aa87bde61a036cf330acf15b5a9c43e..230107554258c7373844df02527d00879367ce11 100644 (file)
@@ -8,7 +8,7 @@ __all__ = [ 'TypeEngine', 'TypeDecorator', 'NullTypeEngine',
             'INT', 'CHAR', 'VARCHAR', 'TEXT', 'FLOAT', 'DECIMAL', 
             'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB', 'BOOLEAN', 'String', 'Integer', 'Smallinteger',
             'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'Binary', 'Boolean', 'Unicode', 'NULLTYPE',
-               'SMALLINT', 'DATE', 'TIME'
+        'SMALLINT', 'DATE', 'TIME'
             ]
 
 import sqlalchemy.util as util
@@ -21,9 +21,20 @@ class TypeEngine(object):
     def convert_result_value(self, value, engine):
         raise NotImplementedError()
     def adapt(self, typeobj):
+        """given a class that is a subclass of this TypeEngine's class, produces a new
+        instance of that class with an equivalent state to this TypeEngine.  The given
+        class is a database-specific subclass which is obtained via a lookup dictionary,
+        mapped against the class returned by the class_to_adapt() method."""
         return typeobj()
     def adapt_args(self):
+        """Returns an instance of this TypeEngine instance's class, adapted according
+        to the constructor arguments of this TypeEngine.  Default return value is 
+        just this object instance."""
         return self
+    def class_to_adapt(self):
+        """returns the class that should be sent to the adapt() method.  This class
+        will be used to lookup an approprate database-specific subclass."""
+        return self.__class__
     def __repr__(self):
         return util.generic_repr(self)
         
@@ -38,7 +49,7 @@ def adapt_type(typeobj, colspecs):
     if typeobj.__module__ != 'sqlalchemy.types' or typeobj.__class__==NullTypeEngine:
         return typeobj
     typeobj = typeobj.adapt_args()
-    t = typeobj.__class__
+    t = typeobj.class_to_adapt()
     for t in t.__mro__[0:-1]:
         try:
             return typeobj.adapt(colspecs[t])
@@ -60,30 +71,40 @@ class TypeDecorator(object):
     def get_col_spec(self):
         return self.extended.get_col_spec()
     def adapt(self, typeobj):
+        self.extended = self.extended.adapt(typeobj)
+        return self
+    def adapt_args(self):
         t = self.__class__.__mro__[2]
-        print repr(t)
-        c = self.__class__()
-        c.extended = t.adapt(self, typeobj)
-        return c
+        self.extended = t.adapt_args(self)
+        return self
+    def class_to_adapt(self):
+        return self.extended.__class__
     
 class String(NullTypeEngine):
-    def __init__(self, length = None, is_unicode=False):
+    def __init__(self, length = None):
         self.length = length
-        self.is_unicode = is_unicode
     def adapt(self, typeobj):
-        return typeobj(self.length, self.is_unicode)
+        return typeobj(self.length)
     def adapt_args(self):
         if self.length is None:
-            return TEXT(is_unicode=self.is_unicode)
+            return TEXT()
         else:
             return self
 
-class Unicode(String):
+class Unicode(TypeDecorator,String):
     def __init__(self, length=None):
-        String.__init__(self, length, is_unicode=True)
-    def adapt(self, typeobj):
-        return typeobj(self.length, is_unicode=True)
-        
+        String.__init__(self, length)
+    def convert_bind_param(self, value, engine):
+         if isinstance(value, unicode):
+              return value.encode('utf-8')
+         else:
+              return value
+    def convert_result_value(self, value, engine):
+         if not isinstance(value, unicode):
+             return value.decode('utf-8')
+         else:
+             return value
+              
 class Integer(NullTypeEngine):
     """integer datatype"""
     # TODO: do string bind params need int(value) performed before sending ?  
index f7b30fd2c152a49e79d6d7a9678aadee12fa82c7..4de4daa4210cfe29f88f92839c2c7ce9d52c4be7 100644 (file)
@@ -1,5 +1,5 @@
 from sqlalchemy import *
-import string,datetime, re
+import string,datetime, re, sys
 from testbase import PersistTest, AssertMixin
 import testbase
     
@@ -70,7 +70,33 @@ class ColumnsTest(AssertMixin):
         for aCol in testTable.c:
             self.assertEquals(expectedResults[aCol.name], db.schemagenerator(None).get_column_specification(aCol))
         
-
+class UnicodeTest(AssertMixin):
+    def setUpAll(self):
+        global unicode_table
+        unicode_table = Table('unicode_table', db, 
+            Column('id', Integer, primary_key=True),
+            Column('unicode_data', Unicode),
+            Column('plain_data', String)
+            )
+        unicode_table.create()
+    def tearDownAll(self):
+        unicode_table.drop()
+    def testbasic(self):
+        rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n'
+        unicodedata = rawdata.decode('utf-8')
+        unicode_table.insert().execute(unicode_data=unicodedata, plain_data=rawdata)
+        x = unicode_table.select().execute().fetchone()
+        self.echo(repr(x['unicode_data']))
+        self.echo(repr(x['plain_data']))
+        self.assert_(isinstance(x['unicode_data'], unicode) and x['unicode_data'] == unicodedata)
+        if isinstance(x['plain_data'], unicode):
+            # SQLLite returns even non-unicode data as unicode
+            self.assert_(sys.modules[db.engine.__module__].descriptor()['name'] == 'sqlite')
+            self.echo("its sqlite !")
+        else:
+            self.assert_(not isinstance(x['plain_data'], unicode) and x['plain_data'] == rawdata)
+            
+    
 class BinaryTest(AssertMixin):
     def setUpAll(self):
         global binary_table