]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implements the boolean type for FB
authorLele Gaifax <lele@metapensiero.it>
Sun, 12 Jul 2009 14:08:52 +0000 (14:08 +0000)
committerLele Gaifax <lele@metapensiero.it>
Sun, 12 Jul 2009 14:08:52 +0000 (14:08 +0000)
Also, on reflection restore the heuristic needed to find the sequence linked to the PK.

lib/sqlalchemy/dialects/firebird/base.py

index ad352ecf2b83be1d51a0a7761c996ef769c9234d..3ba14b5ff5e80912027228df8eacba33197bf566 100644 (file)
@@ -125,7 +125,30 @@ RESERVED_WORDS = set(
     "variable", "varying", "version", "view", "wait", "wait_time", "weekday", "when",
     "whenever", "where", "while", "with", "work", "write", "year", "yearday" ])
 
+
+class _FBBoolean(sqltypes.Boolean):
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
+            return value and True or False
+        return process
+
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is True:
+                return 1
+            elif value is False:
+                return 0
+            elif value is None:
+                return None
+            else:
+                return value and True or False
+        return process
+
+
 colspecs = {
+    sqltypes.Boolean: _FBBoolean,
 }
 
 ischema_names = {
@@ -145,10 +168,13 @@ ischema_names = {
     }
 
 
-# TODO: Boolean type, date conversion types (should be implemented as _FBDateTime, _FBDate, etc.
+# TODO: date conversion types (should be implemented as _FBDateTime, _FBDate, etc.
 # as bind/result functionality is required)
 
 class FBTypeCompiler(compiler.GenericTypeCompiler):
+    def visit_boolean(self, type_):
+        return self.visit_SMALLINT(type_)
+
     def visit_datetime(self, type_):
         return self.visit_TIMESTAMP(type_)
 
@@ -439,8 +465,7 @@ class FBDialect(default.DefaultDialect):
         return pkfields
 
     @reflection.cache
-    def get_column_sequence(self, connection, table_name, column_name,
-                                                        schema=None, **kw):
+    def get_column_sequence(self, connection, table_name, column_name, schema=None, **kw):
         tablename = self.denormalize_name(table_name)
         colname = self.denormalize_name(column_name)
         # Heuristic-query to determine the generator associated to a PK field
@@ -484,6 +509,9 @@ class FBDialect(default.DefaultDialect):
         WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=?
         ORDER BY r.rdb$field_position
         """
+        # get the PK, used to determine the eventual associated sequence
+        pkey_cols = self.get_primary_keys(connection, table_name)
+
         tablename = self.denormalize_name(table_name)
         # get all of the fields for this table
         c = connection.execute(tblqry, [tablename])
@@ -527,6 +555,14 @@ class FBDialect(default.DefaultDialect):
                 'nullable' :  not bool(row['null_flag']),
                 'default' : defvalue
             }
+
+            # if the PK is a single field, try to see if its linked to
+            # a sequence thru a trigger
+            if len(pkey_cols)==1 and name==pkey_cols[0]:
+                seq_d = self.get_column_sequence(connection, tablename, name)
+                if seq_d is not None:
+                    col_d['sequence'] = seq_d
+
             cols.append(col_d)
         return cols