firebird module initial checkin
author    Mike Bayer <mike_mp@zzzcomputing.com>
          Sat, 4 Mar 2006 17:41:19 +0000 (17:41 +0000)
committer Mike Bayer <mike_mp@zzzcomputing.com>
          Sat, 4 Mar 2006 17:41:19 +0000 (17:41 +0000)
lib/sqlalchemy/databases/firebird.py [new file with mode: 0644]

diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py
new file mode 100644 (file)
index 0000000..2509148
--- /dev/null
@@ -0,0 +1,282 @@
+# firebird.py
+# Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+
+import sys, StringIO, string
+
+import sqlalchemy.sql as sql
+import sqlalchemy.schema as schema
+import sqlalchemy.ansisql as ansisql
+from sqlalchemy import *
+import sqlalchemy.types as sqltypes
+
+try:
+    import kinterbasdb
+except ImportError:
+    kinterbasdb = None
+
+class FBNumeric(sqltypes.Numeric):
+    def get_col_spec(self):
+        return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
+class FBInteger(sqltypes.Integer):
+    def get_col_spec(self):
+        return "INTEGER"
+class FBSmallInteger(sqltypes.Smallinteger):
+    def get_col_spec(self):
+        return "SMALLINT"
+class FBDateTime(sqltypes.DateTime):
+    def get_col_spec(self):
+        return "DATE"
+class FBText(sqltypes.TEXT):
+    def get_col_spec(self):
+        return "BLOB SUB_TYPE 2"
+class FBString(sqltypes.String):
+    def get_col_spec(self):
+        return "VARCHAR(%(length)s)" % {'length' : self.length}
+class FBChar(sqltypes.CHAR):
+    def get_col_spec(self):
+        return "CHAR(%(length)s)" % {'length' : self.length}
+class FBBinary(sqltypes.Binary):
+    def get_col_spec(self):
+        return "BLOB SUB_TYPE 1"
+class FBBoolean(sqltypes.Boolean):
+    def get_col_spec(self):
+        return "SMALLINT"
+        
+colspecs = {
+    sqltypes.Integer : FBInteger,
+    sqltypes.Smallinteger : FBSmallInteger,
+    sqltypes.Numeric : FBNumeric,
+    sqltypes.Float : FBNumeric,
+    sqltypes.DateTime : FBDateTime,
+    sqltypes.Date : FBDateTime,
+    sqltypes.String : FBString,
+    sqltypes.Binary : FBBinary,
+    sqltypes.Boolean : FBBoolean,
+    sqltypes.TEXT : FBText,
+    sqltypes.CHAR: FBChar,
+}
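+
+# Illustrative sketch: colspecs is consulted by FBSQLEngine.type_descriptor()
+# below, adapting a generic type to its Firebird-specific counterpart, e.g.
+# (assuming a hypothetical String(30) column type):
+#
+#   fbtype = sqltypes.adapt_type(sqltypes.String(30), colspecs)
+#   fbtype.get_col_spec()    # roughly "VARCHAR(30)"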
+
+def engine(*args, **params):
+    return FBSQLEngine(*args, **params)
+
+def descriptor():
+    return {'name':'firebird',
+    'description':'Firebird',
+    'arguments':[
+        ('host', 'Host Server Name', None),
+        ('database', 'Database Name', None),
+        ('user', 'Username', None),
+        ('password', 'Password', None)
+    ]}
+    
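+# Usage sketch (assumed 0.1-era create_engine(name, opts) form, with
+# placeholder connection values): the argument names listed in descriptor()
+# map onto engine construction roughly as
+#
+#   from sqlalchemy import create_engine
+#   db = create_engine('firebird', {'host': 'localhost',
+#                                   'database': '/data/test.fdb',
+#                                   'user': 'sysdba',
+#                                   'password': 'masterkey'})
+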
+class FBSQLEngine(ansisql.ANSISQLEngine):
+    def __init__(self, opts, use_ansi = True, module = None, **params):
+        self._use_ansi = use_ansi
+        self.opts = opts or {}
+        if module is None:
+            self.module = kinterbasdb
+        else:
+            self.module = module
+        ansisql.ANSISQLEngine.__init__(self, **params)
+
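+    # kinterbasdb's commit()/rollback() take an optional "retaining" flag;
+    # passing True issues COMMIT RETAINING / ROLLBACK RETAINING, keeping the
+    # transaction context open for subsequent statements.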
+    def do_commit(self, connection):
+        connection.commit(True)
+
+    def do_rollback(self, connection):
+        connection.rollback(True)
+
+    def dbapi(self):
+        return self.module
+
+    def connect_args(self):
+        return [[], self.opts]
+        
+    def type_descriptor(self, typeobj):
+        return sqltypes.adapt_type(typeobj, colspecs)
+
+    def last_inserted_ids(self):
+        return self.context.last_inserted_ids
+
+    def compiler(self, statement, bindparams, **kwargs):
+        return FBCompiler(self, statement, bindparams, use_ansi=self._use_ansi, **kwargs)
+
+    def schemagenerator(self, **params):
+        return FBSchemaGenerator(self, **params)
+    
+    def schemadropper(self, **params):
+        return FBSchemaDropper(self, **params)
+    def defaultrunner(self, proxy):
+        return FBDefaultRunner(self, proxy)
+        
+    def reflecttable(self, table):
+        #TODO: map these better
+        column_func = {
+            14 : lambda r: String(r['FLEN']), # TEXT
+            7  : lambda r: Integer(), # SHORT
+            8  : lambda r: Integer(), # LONG
+            9  : lambda r: Float(), # QUAD
+            10 : lambda r: Float(), # FLOAT
+            27 : lambda r: Float(), # DOUBLE (no separate Double type; map to Float)
+            35 : lambda r: DateTime(), # TIMESTAMP
+            37 : lambda r: String(r['FLEN']), # VARYING
+            261: lambda r: TEXT(), # BLOB
+            40 : lambda r: CHAR(r['FLEN']), # CSTRING
+            12 : lambda r: Date(), # DATE
+            13 : lambda r: Time(), # TIME
+            16 : lambda r: Numeric(precision=r['FPREC'], length=r['FSCALE'] * -1)  #INT64
+            }
+        tblqry = """\
+        SELECT DISTINCT R.RDB$FIELD_NAME AS FNAME,
+                  R.RDB$NULL_FLAG AS NULL_FLAG,
+                  R.RDB$FIELD_POSITION,
+                  F.RDB$FIELD_TYPE AS FTYPE,
+                  F.RDB$FIELD_SUB_TYPE AS STYPE,
+                  F.RDB$FIELD_LENGTH AS FLEN,
+                  F.RDB$FIELD_PRECISION AS FPREC,
+                  F.RDB$FIELD_SCALE AS FSCALE
+        FROM RDB$RELATION_FIELDS R 
+             JOIN RDB$FIELDS F ON R.RDB$FIELD_SOURCE=F.RDB$FIELD_NAME
+        WHERE F.RDB$SYSTEM_FLAG=0 and R.RDB$RELATION_NAME=?
+        ORDER BY R.RDB$FIELD_POSITION;"""
+        keyqry = """
+        SELECT RC.RDB$CONSTRAINT_TYPE KEYTYPE,
+               RC.RDB$CONSTRAINT_NAME CNAME,
+               RC.RDB$INDEX_NAME INAME,
+               SE.RDB$FIELD_NAME SENAME
+        FROM RDB$RELATION_CONSTRAINTS RC
+            LEFT JOIN RDB$INDEX_SEGMENTS SE
+              ON RC.RDB$INDEX_NAME=SE.RDB$INDEX_NAME
+        WHERE RC.RDB$RELATION_NAME=? AND SE.RDB$FIELD_NAME=?
+        """
+    
+        #import pdb;pdb.set_trace()
+        # get all of the fields for this table
+        c = self.execute(tblqry, [table.name.upper()])
+        while True:
+            row = c.fetchone()
+            if not row: break
+            args = [row['FNAME']]
+            kw = {}
+            # get the data types and lengths
+            args.append(column_func[row['FTYPE']](row))
+
+            # is it a foreign key (and what is it linked to)
+
+            # is it a primary key?
+
+            table.append_item(Column(*args, **kw))
+            # does the field have indexes
+
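+    # Reflection sketch (assumed autoload usage, hypothetical table name):
+    # reflecttable() is driven by table reflection, roughly
+    #
+    #   employees = Table('employees', db, autoload=True)
+    #
+    # which populates the Table's columns from the RDB$ system tables above.
+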
+    def pre_exec(self, proxy, compiled, parameters, **kwargs):
+        # this is just an assertion that all the primary key columns in an insert statement
+        # have a value set up, or have a default generator ready to go
+        if getattr(compiled, "isinsert", False):
+            if isinstance(parameters, list):
+                plist = parameters
+            else:
+                plist = [parameters]
+            for param in plist:
+                for primary_key in compiled.statement.table.primary_key:
+                    if not param.has_key(primary_key.key) or param[primary_key.key] is None:
+                        if primary_key.default is None:
+                            raise ValueError("Column '%s.%s': Firebird primary key columns require a default value or a schema.Sequence to create ids" % (primary_key.table.name, primary_key.name))
+
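+    # Sketch of the requirement checked above (hypothetical table and generator
+    # names): a primary key column needs an explicit value or a default such as
+    # a schema.Sequence, e.g.
+    #
+    #   people = Table('people', db,
+    #       Column('id', Integer, Sequence('gen_people_id'), primary_key=True),
+    #       Column('name', String(40)))
+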
+    def _executemany(self, c, statement, parameters):
+        rowcount = 0
+        for param in parameters:
+            c.execute(statement, param)
+            rowcount += c.rowcount
+        self.context.rowcount = rowcount
+
+class FBCompiler(ansisql.ANSICompiler):
+    """firebird compiler modifies the lexical structure of Select statements to work under 
+    non-ANSI configured Firebird databases, if the use_ansi flag is False."""
+    
+    def __init__(self, engine, statement, parameters, use_ansi = True, **kwargs):
+        self._outertable = None
+        self._use_ansi = use_ansi
+        ansisql.ANSICompiler.__init__(self, engine, statement, parameters, **kwargs)
+      
+    def visit_column(self, column):
+        if self._use_ansi:
+            return ansisql.ANSICompiler.visit_column(self, column)
+            
+        if column.table is self._outertable:
+            self.strings[column] = "%s.%s(+)" % (column.table.name, column.name)
+        else:
+            self.strings[column] = "%s.%s" % (column.table.name, column.name)
+       
+    def visit_function(self, func):
+        if len(func.clauses):
+            super(FBCompiler, self).visit_function(func)
+        else:
+            self.strings[func] = func.name
+    def visit_insert(self, insert):
+        """inserts are required to have the primary keys be explicitly present.
+         mapper will by default not put them in the insert statement to comply
+         with autoincrement fields that require they not be present.  so, 
+         put them all in for all primary key columns."""
+        for c in insert.table.primary_key:
+            if not self.parameters.has_key(c.key):
+                self.parameters[c.key] = None
+        return ansisql.ANSICompiler.visit_insert(self, insert)
+           
+    def visit_select_precolumns(self, select):
+        """Called when building a SELECT statement, just after the SELECT keyword and
+        before the column list.  Firebird places its row-limiting clauses there."""
+        result = ""
+        if select.limit:
+            # FIRST n caps the number of rows returned (Firebird's LIMIT)
+            result += " FIRST %s" % select.limit
+        if select.offset:
+            # SKIP n skips over the leading rows (Firebird's OFFSET)
+            result += " SKIP %s" % select.offset
+        if select.distinct:
+            result += " DISTINCT "
+        return result
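+
+    # Example of the row-limiting prefix produced above (illustrative values):
+    #
+    #   SELECT FIRST 10 SKIP 20 EMPLOYEES.NAME FROM EMPLOYEES
+    #
+    # returns ten rows after skipping the first twenty.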
+
+    def limit_clause(self, select):
+        """Already taken care of in the visit_select_precolumns method."""
+        return ""
+
+class FBSchemaGenerator(ansisql.ANSISchemaGenerator):
+    def get_column_specification(self, column, override_pk=False, **kwargs):
+        colspec = column.name 
+        colspec += " " + column.type.get_col_spec()
+        default = self.get_column_default_string(column)
+        if default is not None:
+            colspec += " DEFAULT " + default
+
+        if not column.nullable:
+            colspec += " NOT NULL"
+        if column.primary_key and not override_pk:
+            colspec += " PRIMARY KEY"
+        if column.foreign_key:
+            colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.name, column.foreign_key.column.name) 
+        return colspec
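+
+    # Sketch of the DDL fragment produced above (hypothetical column): a
+    # Column('name', String(40), nullable=False, primary_key=True) renders
+    # roughly as
+    #
+    #   name VARCHAR(40) NOT NULL PRIMARY KEY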
+
+    def visit_sequence(self, sequence):
+        self.append("CREATE GENERATOR %s" % sequence.name)
+        self.execute()
+
+class FBSchemaDropper(ansisql.ANSISchemaDropper):
+    def visit_sequence(self, sequence):
+        self.append("DROP GENERATOR %s" % sequence.name)
+        self.execute()
+
+class FBDefaultRunner(ansisql.ANSIDefaultRunner):
+    def exec_default_sql(self, default):
+        c = sql.select([default.arg], from_obj=["rdb$database"], engine=self.engine).compile()
+        return self.proxy(str(c), c.get_params()).fetchone()[0]
+    
+    def visit_sequence(self, seq):
+        return self.proxy("SELECT gen_id(" + seq.name + ", 1) FROM rdb$database").fetchone()[0]
+
+