]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Firebird: added Float and Time types (FBFloat and FBTime). Fixed BLOB SUB_TYPE for...
authorRoger Demetrescu <roger.demetrescu@gmail.com>
Fri, 12 Oct 2007 06:02:15 +0000 (06:02 +0000)
committerRoger Demetrescu <roger.demetrescu@gmail.com>
Fri, 12 Oct 2007 06:02:15 +0000 (06:02 +0000)
Firebird's string types are tested in testtypes.py

CHANGES
lib/sqlalchemy/databases/firebird.py
test/sql/testtypes.py

diff --git a/CHANGES b/CHANGES
index 2187433a6d5d23ca04e52209338d141ced0b1067..53ce0e75eef923378b4eacc4a5a9e282fc91310b 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -60,6 +60,10 @@ CHANGES
 
 - PickleType and Interval types (on db not supporting it natively) are now 
   slightly faster.
+  
+- Added Float and Time types to Firebird (FBFloat and FBTime). Fixed
+  BLOB SUB_TYPE for TEXT and Binary types.
+
 
 0.4.0beta6
 ----------
index 22537d7ba8f5b227a865159966ca7bdab138051f..a427c72203c1f1d097945ff79b341e66717c4549 100644 (file)
@@ -5,6 +5,7 @@
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
 
+import datetime
 import warnings
 
 from sqlalchemy import util, sql, schema, exceptions, pool
@@ -24,6 +25,29 @@ class FBNumeric(sqltypes.Numeric):
             return "NUMERIC(%(precision)s, %(length)s)" % { 'precision': self.precision,
                                                             'length' : self.length }
 
+    def bind_processor(self, dialect):
+        return None
+
+    def result_processor(self, dialect):
+        if self.asdecimal:
+            return None
+        else:
+            def process(value):
+                if isinstance(value, util.decimal_type):
+                    return float(value)
+                else:
+                    return value
+            return process
+
+
+class FBFloat(sqltypes.Float):
+    def get_col_spec(self):
+        if not self.precision:
+            return "FLOAT"
+        else:
+            return "FLOAT(%(precision)s)" % {'precision': self.precision}
+
+
 class FBInteger(sqltypes.Integer):
     def get_col_spec(self):
         return "INTEGER"
@@ -38,15 +62,29 @@ class FBDateTime(sqltypes.DateTime):
     def get_col_spec(self):
         return "TIMESTAMP"
 
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is None or isinstance(value, datetime.datetime):
+                return value
+            else:
+                return datetime.datetime(year=value.year, month=value.month, 
+                    day=value.day)
+        return process
+
 
 class FBDate(sqltypes.DateTime):
     def get_col_spec(self):
         return "DATE"
 
 
+class FBTime(sqltypes.Time):
+    def get_col_spec(self):
+        return "TIME"
+
+
 class FBText(sqltypes.TEXT):
     def get_col_spec(self):
-        return "BLOB SUB_TYPE 2"
+        return "BLOB SUB_TYPE 1"
 
 
 class FBString(sqltypes.String):
@@ -61,7 +99,7 @@ class FBChar(sqltypes.CHAR):
 
 class FBBinary(sqltypes.Binary):
     def get_col_spec(self):
-        return "BLOB SUB_TYPE 1"
+        return "BLOB SUB_TYPE 0"
 
 
 class FBBoolean(sqltypes.Boolean):
@@ -73,9 +111,10 @@ colspecs = {
     sqltypes.Integer : FBInteger,
     sqltypes.Smallinteger : FBSmallInteger,
     sqltypes.Numeric : FBNumeric,
-    sqltypes.Float : FBNumeric,
+    sqltypes.Float : FBFloat,
     sqltypes.DateTime : FBDateTime,
     sqltypes.Date : FBDate,
+    sqltypes.Time : FBTime,
     sqltypes.String : FBString,
     sqltypes.Binary : FBBinary,
     sqltypes.Boolean : FBBoolean,
index c11fa0bdf1c6a19d67edab0bf0557a9077a1e741..71313ec42de149197e7beeabb35ce31b65f9996e 100644 (file)
@@ -4,7 +4,7 @@ import datetime, os
 from sqlalchemy import *
 from sqlalchemy import types
 import sqlalchemy.engine.url as url
-from sqlalchemy.databases import mssql, oracle, mysql, postgres
+from sqlalchemy.databases import mssql, oracle, mysql, postgres, firebird
 from testlib import *
 
 
@@ -36,16 +36,16 @@ class MyDecoratedType(types.TypeDecorator):
         return process
     def copy(self):
         return MyDecoratedType()
-        
+
 class MyUnicodeType(types.TypeDecorator):
     impl = Unicode
-    
+
     def bind_processor(self, dialect):
         impl_processor = super(MyUnicodeType, self).bind_processor(dialect)
         def process(value):
             return "UNI_BIND_IN"+ impl_processor(value)
         return process
-        
+
     def result_processor(self, dialect):
         impl_processor = super(MyUnicodeType, self).result_processor(dialect)
         def process(value):
@@ -82,16 +82,23 @@ class AdaptTest(PersistTest):
         e1 = url.URL('postgres').get_dialect()()
         e2 = url.URL('mysql').get_dialect()()
         e3 = url.URL('sqlite').get_dialect()()
-        
+        e4 = url.URL('firebird').get_dialect()()
+
         type = String(40)
-        
+
         t1 = type.dialect_impl(e1)
         t2 = type.dialect_impl(e2)
         t3 = type.dialect_impl(e3)
-        assert t1 != t2
-        assert t2 != t3
-        assert t3 != t1
-    
+        t4 = type.dialect_impl(e4)
+
+        impls = [t1, t2, t3, t4]
+        for i,ta in enumerate(impls):
+            for j,tb in enumerate(impls):
+                if i == j:
+                    assert ta == tb  # call me paranoid...  :)
+                else:
+                    assert ta != tb
+
     def testmsnvarchar(self):
         dialect = mssql.MSSQLDialect()
         # run the test twice to insure the caching step works too
@@ -123,14 +130,15 @@ class AdaptTest(PersistTest):
         t2 = mysql.MSVarBinary()
         assert isinstance(dialect.type_descriptor(t1), mysql.MSVarBinary)
         assert isinstance(dialect.type_descriptor(t2), mysql.MSVarBinary)
-    
+
     def teststringadapt(self):
         """test that String with no size becomes TEXT, *all* others stay as varchar/String"""
-        
+
         oracle_dialect = oracle.OracleDialect()
         mysql_dialect = mysql.MySQLDialect()
         postgres_dialect = postgres.PGDialect()
-        
+        firebird_dialect = firebird.FBDialect()
+
         for dialect, start, test in [
             (oracle_dialect, String(), oracle.OracleText),
             (oracle_dialect, VARCHAR(), oracle.OracleString),
@@ -147,11 +155,16 @@ class AdaptTest(PersistTest):
             (postgres_dialect, String(50), postgres.PGString),
             (postgres_dialect, Unicode(), postgres.PGText),
             (postgres_dialect, NCHAR(), postgres.PGString),
+            (firebird_dialect, String(), firebird.FBText),
+            (firebird_dialect, VARCHAR(), firebird.FBString),
+            (firebird_dialect, String(50), firebird.FBString),
+            (firebird_dialect, Unicode(), firebird.FBText),
+            (firebird_dialect, NCHAR(), firebird.FBString),
         ]:
             assert isinstance(start.dialect_impl(dialect), test), "wanted %r got %r" % (test, start.dialect_impl(dialect))
-        
-        
-        
+
+
+
 class UserDefinedTest(PersistTest):
     """tests user-defined types."""
 
@@ -159,17 +172,17 @@ class UserDefinedTest(PersistTest):
         print users.c.goofy4.type
         print users.c.goofy4.type.dialect_impl(testbase.db.dialect)
         print users.c.goofy4.type.dialect_impl(testbase.db.dialect).get_col_spec()
-        
+
     def testprocessing(self):
 
         global users
         users.insert().execute(user_id = 2, goofy = 'jack', goofy2='jack', goofy3='jack', goofy4='jack', goofy5='jack', goofy6='jack')
         users.insert().execute(user_id = 3, goofy = 'lala', goofy2='lala', goofy3='lala', goofy4='lala', goofy5='lala', goofy6='lala')
         users.insert().execute(user_id = 4, goofy = 'fred', goofy2='fred', goofy3='fred', goofy4='fred', goofy5='fred', goofy6='fred')
-        
+
         l = users.select().execute().fetchall()
         assert l == [
-            (2, 'BIND_INjackBIND_OUT', 'BIND_INjackBIND_OUT', 'BIND_INjackBIND_OUT', u'UNI_BIND_INjackUNI_BIND_OUT', u'UNI_BIND_INjackUNI_BIND_OUT', 'BIND_INjackBIND_OUT'), 
+            (2, 'BIND_INjackBIND_OUT', 'BIND_INjackBIND_OUT', 'BIND_INjackBIND_OUT', u'UNI_BIND_INjackUNI_BIND_OUT', u'UNI_BIND_INjackUNI_BIND_OUT', 'BIND_INjackBIND_OUT'),
             (3, 'BIND_INlalaBIND_OUT', 'BIND_INlalaBIND_OUT', 'BIND_INlalaBIND_OUT', u'UNI_BIND_INlalaUNI_BIND_OUT', u'UNI_BIND_INlalaUNI_BIND_OUT', 'BIND_INlalaBIND_OUT'),
             (4, 'BIND_INfredBIND_OUT', 'BIND_INfredBIND_OUT', 'BIND_INfredBIND_OUT', u'UNI_BIND_INfredUNI_BIND_OUT', u'UNI_BIND_INfredUNI_BIND_OUT', 'BIND_INfredBIND_OUT')
         ]
@@ -181,10 +194,10 @@ class UserDefinedTest(PersistTest):
             Column('user_id', Integer, primary_key = True),
             # totall custom type
             Column('goofy', MyType, nullable = False),
-            
+
             # decorated type with an argument, so its a String
             Column('goofy2', MyDecoratedType(50), nullable = False),
-            
+
             # decorated type without an argument, it will adapt_args to TEXT
             Column('goofy3', MyDecoratedType, nullable = False),
 
@@ -193,9 +206,9 @@ class UserDefinedTest(PersistTest):
             Column('goofy6', LegacyType, nullable = False),
 
         )
-        
+
         metadata.create_all()
-        
+
     def tearDownAll(self):
         metadata.drop_all()
 
@@ -212,7 +225,7 @@ class ColumnsTest(AssertMixin):
         db = testbase.db
         if not db.name=='sqlite' and not db.name=='oracle':
             expectedResults['float_column'] = 'float_column FLOAT(25)'
-    
+
         print db.engine.__module__
         testTable = Table('testColumns', MetaData(db),
             Column('int_column', Integer),
@@ -224,13 +237,13 @@ class ColumnsTest(AssertMixin):
 
         for aCol in testTable.c:
             self.assertEquals(expectedResults[aCol.name], db.dialect.schemagenerator(db.dialect, db, None, None).get_column_specification(aCol))
-        
+
 class UnicodeTest(AssertMixin):
     """tests the Unicode type.  also tests the TypeDecorator with instances in the types package."""
     def setUpAll(self):
         global unicode_table
         metadata = MetaData(testbase.db)
-        unicode_table = Table('unicode_table', metadata, 
+        unicode_table = Table('unicode_table', metadata,
             Column('id', Integer, Sequence('uni_id_seq', optional=True), primary_key=True),
             Column('unicode_varchar', Unicode(250)),
             Column('unicode_text', Unicode),
@@ -239,10 +252,10 @@ class UnicodeTest(AssertMixin):
         unicode_table.create()
     def tearDownAll(self):
         unicode_table.drop()
-    
+
     def tearDown(self):
         unicode_table.delete().execute()
-        
+
     def testbasic(self):
         assert unicode_table.c.unicode_varchar.type.length == 250
         rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n'
@@ -268,7 +281,7 @@ class UnicodeTest(AssertMixin):
     def testblanks(self):
         unicode_table.insert().execute(unicode_varchar=u'')
         assert select([unicode_table.c.unicode_varchar]).scalar() == u''
-        
+
     def testengineparam(self):
         """tests engine-wide unicode conversion"""
         prev_unicode = testbase.db.engine.dialect.convert_unicode
@@ -294,18 +307,18 @@ class UnicodeTest(AssertMixin):
         """checks the database correctly understands the length of a unicode string"""
         teststr = u'aaa\x1234'
         self.assert_(testbase.db.func.length(teststr).scalar() == len(teststr))
-  
+
 class BinaryTest(AssertMixin):
     def setUpAll(self):
         global binary_table
-        binary_table = Table('binary_table', MetaData(testbase.db), 
+        binary_table = Table('binary_table', MetaData(testbase.db),
         Column('primary_id', Integer, Sequence('binary_id_seq', optional=True), primary_key=True),
         Column('data', Binary),
         Column('data_slice', Binary(100)),
         Column('misc', String(30)),
         # construct PickleType with non-native pickle module, since cPickle uses relative module
         # loading and confuses this test's parent package 'sql' with the 'sqlalchemy.sql' package relative
-       # to the 'types' module
+        # to the 'types' module
         Column('pickled', PickleType)
         )
         binary_table.create()
@@ -325,7 +338,7 @@ class BinaryTest(AssertMixin):
         binary_table.insert().execute(primary_id=1, misc='binary_data_one.dat',    data=stream1, data_slice=stream1[0:100], pickled=testobj1)
         binary_table.insert().execute(primary_id=2, misc='binary_data_two.dat', data=stream2, data_slice=stream2[0:99], pickled=testobj2)
         binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data=None, data_slice=stream2[0:99], pickled=None)
-        
+
         for stmt in (
             binary_table.select(order_by=binary_table.c.primary_id),
             text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType}, bind=testbase.db)
@@ -343,8 +356,8 @@ class BinaryTest(AssertMixin):
         f = os.path.join(os.path.dirname(testbase.__file__), name)
         # put a number less than the typical MySQL default BLOB size
         return file(f).read(len)
-    
-    
+
+
 class DateTest(AssertMixin):
     def setUpAll(self):
         global users_with_date, insert_data
@@ -380,9 +393,9 @@ class DateTest(AssertMixin):
             time_micro = 999
 
             # Missing or poor microsecond support:
-            if db.engine.name in ('mssql', 'mysql'):
+            if db.engine.name in ('mssql', 'mysql', 'firebird'):
                 datetime_micro, time_micro = 0, 0
-            
+
             insert_data =  [
                 [7, 'jack',
                  datetime.datetime(2005, 11, 10, 0, 0),
@@ -406,7 +419,7 @@ class DateTest(AssertMixin):
                        Column('user_datetime', DateTime(timezone=False)),
                        Column('user_date', Date),
                        Column('user_time', Time)]
+
         users_with_date = Table('query_users_with_date',
                                 MetaData(testbase.db), *collist)
         users_with_date.create()
@@ -424,15 +437,15 @@ class DateTest(AssertMixin):
         l = map(list, users_with_date.select().execute().fetchall())
         self.assert_(l == insert_data,
                      'DateTest mismatch: got:%s expected:%s' % (l, insert_data))
-        
+
     def testtextdate(self):
         x = testbase.db.text(
             "select user_datetime from query_users_with_date",
             typemap={'user_datetime':DateTime}).execute().fetchall()
-        
+
         print repr(x)
         self.assert_(isinstance(x[0][0], datetime.datetime))
-        
+
         x = testbase.db.text(
             "select * from query_users_with_date where user_datetime=:somedate",
             bindparams=[bindparam('somedate', type_=types.DateTime)]).execute(
@@ -473,13 +486,13 @@ class NumericTest(AssertMixin):
             Column('fcasdec', Float(asdecimal=True))
         )
         metadata.create_all()
-        
+
     def tearDownAll(self):
         metadata.drop_all()
-        
+
     def tearDown(self):
         numeric_table.delete().execute()
-        
+
     def test_decimal(self):
         from decimal import Decimal
         numeric_table.insert().execute(numericcol=3.5, floatcol=5.6, ncasdec=12.4, fcasdec=15.78)
@@ -494,24 +507,24 @@ class NumericTest(AssertMixin):
             (1, 3.5, 5.6, Decimal("12.4"), Decimal("15.78")),
             (2, 3.5, 5.6, Decimal("12.4"), Decimal("15.78")),
         ]
-        
-            
+
+
 class IntervalTest(AssertMixin):
     def setUpAll(self):
         global interval_table, metadata
         metadata = MetaData(testbase.db)
-        interval_table = Table("intervaltable", metadata, 
+        interval_table = Table("intervaltable", metadata,
             Column("id", Integer, Sequence('interval_id_seq', optional=True), primary_key=True),
             Column("interval", Interval),
             )
         metadata.create_all()
-    
+
     def tearDown(self):
         interval_table.delete().execute()
-            
+
     def tearDownAll(self):
         metadata.drop_all()
-        
+
     def test_roundtrip(self):
         delta = datetime.datetime(2006, 10, 5) - datetime.datetime(2005, 8, 17)
         interval_table.insert().execute(interval=delta)
@@ -520,12 +533,12 @@ class IntervalTest(AssertMixin):
     def test_null(self):
         interval_table.insert().execute(id=1, inverval=None)
         assert interval_table.select().execute().fetchone()['interval'] is None
-        
+
 class BooleanTest(AssertMixin):
     def setUpAll(self):
         global bool_table
         metadata = MetaData(testbase.db)
-        bool_table = Table('booltest', metadata, 
+        bool_table = Table('booltest', metadata,
             Column('id', Integer, primary_key=True),
             Column('value', Boolean))
         bool_table.create()
@@ -537,11 +550,11 @@ class BooleanTest(AssertMixin):
         bool_table.insert().execute(id=3, value=True)
         bool_table.insert().execute(id=4, value=True)
         bool_table.insert().execute(id=5, value=True)
-        
+
         res = bool_table.select(bool_table.c.value==True).execute().fetchall()
         print res
         assert(res==[(1, True),(3, True),(4, True),(5, True)])
-        
+
         res2 = bool_table.select(bool_table.c.value==False).execute().fetchall()
         print res2
         assert(res2==[(2, False)])