]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
begin modernize of informix dialect for [ticket:1499]
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 31 Dec 2009 17:45:19 +0000 (17:45 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 31 Dec 2009 17:45:19 +0000 (17:45 +0000)
lib/sqlalchemy/dialects/informix/base.py
lib/sqlalchemy/dialects/informix/informixdb.py

index 9fbfbf011717a2efa950716ee49cdb7382bc5048..5240efdc2358768a32cdd12e5179f7dffd2215d7 100644 (file)
@@ -7,8 +7,6 @@
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 """Support for the Informix database.
 
-This dialect is *not* ported to SQLAlchemy 0.6.
-
 This dialect is *not* tested on SQLAlchemy 0.6.
 
 
@@ -19,7 +17,7 @@ import datetime
 
 from sqlalchemy import sql, schema, exc, pool, util
 from sqlalchemy.sql import compiler
-from sqlalchemy.engine import default
+from sqlalchemy.engine import default, reflection
 from sqlalchemy import types as sqltypes
 
 
@@ -50,30 +48,9 @@ class InfoTime(sqltypes.Time):
         return process
 
 
-class InfoBoolean(sqltypes.Boolean):
-    def result_processor(self, dialect, coltype):
-        def process(value):
-            if value is None:
-                return None
-            return value and True or False
-        return process
-
-    def bind_processor(self, dialect):
-        def process(value):
-            if value is True:
-                return 1
-            elif value is False:
-                return 0
-            elif value is None:
-                return None
-            else:
-                return value and True or False
-        return process
-
 colspecs = {
     sqltypes.DateTime : InfoDateTime,
     sqltypes.Time: InfoTime,
-    sqltypes.Boolean : InfoBoolean,
 }
 
 
@@ -116,12 +93,6 @@ class InfoTypeCompiler(compiler.GenericTypeCompiler):
 
 class InfoSQLCompiler(compiler.SQLCompiler):
 
-    def __init__(self, *args, **kwargs):
-        self.limit = 0
-        self.offset = 0
-
-        compiler.SQLCompiler.__init__(self, *args, **kwargs)
-
     def default_from(self):
         return " from systables where tabname = 'systables' "
 
@@ -129,16 +100,12 @@ class InfoSQLCompiler(compiler.SQLCompiler):
         s = select._distinct and "DISTINCT " or ""
         # only has limit
         if select._limit:
-            off = select._offset or 0
-            s += " FIRST %s " % (select._limit + off)
+            s += " FIRST %s " % select._limit
         else:
             s += ""
         return s
 
     def visit_select(self, select):
-        if select._offset:
-            self.offset = select._offset
-            self.limit  = select._limit or 0
         # the column in order by clause must in select too
 
         def __label(c):
@@ -156,6 +123,8 @@ class InfoSQLCompiler(compiler.SQLCompiler):
         return compiler.SQLCompiler.visit_select(self, select)
 
     def limit_clause(self, select):
+        if select._offset is not None and select._offset > 0:
+            raise NotImplementedError("Informix does not support OFFSET")
         return ""
 
     def visit_function(self, func):
@@ -168,16 +137,13 @@ class InfoSQLCompiler(compiler.SQLCompiler):
         else:
             return compiler.SQLCompiler.visit_function(self, func)
 
-    def visit_clauselist(self, list, **kwargs):
-        return ', '.join([s for s in [self.process(c) for c in list.clauses] if s is not None])
 
 class InfoDDLCompiler(compiler.DDLCompiler):
     def get_column_specification(self, column, first_pk=False):
         colspec = self.preparer.format_column(column)
         if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and \
-           isinstance(column.type, sqltypes.Integer) and not getattr(self, 'has_serial', False) and first_pk:
+           isinstance(column.type, sqltypes.Integer) and first_pk:
             colspec += " SERIAL"
-            self.has_serial = True
         else:
             colspec += " " + self.dialect.type_compiler.process(column.type)
             default = self.get_column_default_string(column)
@@ -189,10 +155,6 @@ class InfoDDLCompiler(compiler.DDLCompiler):
 
         return colspec
 
-    def post_create_table(self, table):
-        if hasattr(self, 'has_serial'):
-            del self.has_serial
-        return ''
 
 class InfoIdentifierPreparer(compiler.IdentifierPreparer):
     def __init__(self, dialect):
@@ -207,18 +169,25 @@ class InfoIdentifierPreparer(compiler.IdentifierPreparer):
 
 class InformixDialect(default.DefaultDialect):
     name = 'informix'
-    # for informix 7.31
-    max_identifier_length = 18
+
+    max_identifier_length = 128 # adjusts at runtime based on server version
+    
     type_compiler = InfoTypeCompiler
-    poolclass = pool.SingletonThreadPool
     statement_compiler = InfoSQLCompiler
     ddl_compiler = InfoDDLCompiler
     preparer = InfoIdentifierPreparer
     colspecs = colspecs
     ischema_names = ischema_names
 
-    ported_sqla_06 = False
-
+    def initialize(self, connection):
+        super(InformixDialect, self).initialize(connection)
+        
+        # http://www.querix.com/support/knowledge-base/error_number_message/error_200
+        if self.server_version_info < (9, 2):
+            self.max_identifier_length = 18
+        else:
+            self.max_identifier_length = 128
+        
     def do_begin(self, connect):
         cu = connect.cursor()
         cu.execute('SET LOCK MODE TO WAIT')
@@ -232,32 +201,14 @@ class InformixDialect(default.DefaultDialect):
         cursor = connection.execute("""select tabname from systables where tabname=?""", table_name.lower())
         return cursor.first() is not None
 
-    def reflecttable(self, connection, table, include_columns):
-        c = connection.execute ("select distinct OWNER from systables where tabname=?", table.name.lower())
-        rows = c.fetchall()
-        if not rows :
-            raise exc.NoSuchTableError(table.name)
-        else:
-            if table.owner is not None:
-                if table.owner.lower() in [r[0] for r in rows]:
-                    owner = table.owner.lower()
-                else:
-                    raise AssertionError("Specified owner %s does not own table %s"%(table.owner, table.name))
-            else:
-                if len(rows)==1:
-                    owner = rows[0][0]
-                else:
-                    raise AssertionError("There are multiple tables with name %s in the schema, you must specifie owner"%table.name)
-
-        c = connection.execute ("""select colname , coltype , collength , t3.default , t1.colno from syscolumns as t1 , systables as t2 , OUTER sysdefaults as t3
-                                    where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=?
+    @reflection.cache
+    def get_columns(self, connection, table_name, schema=None, **kw):
+        c = connection.execute ("""select colname , coltype , collength , t3.default , t1.colno from
+                            syscolumns as t1 , systables as t2 , OUTER sysdefaults as t3
+                                    where t1.tabid = t2.tabid and t2.tabname=? 
                                       and t3.tabid = t2.tabid and t3.colno = t1.colno
-                                    order by t1.colno""", table.name.lower(), owner)
-        rows = c.fetchall()
-
-        if not rows:
-            raise exc.NoSuchTableError(table.name)
-
+                                    order by t1.colno""", table.name.lower())
+        columns = []
         for name, colattr, collength, default, colno in rows:
             name = name.lower()
             if include_columns and name not in include_columns:
@@ -271,14 +222,14 @@ class InformixDialect(default.DefaultDialect):
                 default = default.split()[-1]
 
             if coltype == 0 or coltype == 13: # char, varchar
-                coltype = ischema_names.get(coltype, InfoString)(collength)
+                coltype = ischema_names[coltype](collength)
                 if default:
                     default = "'%s'" % default
             elif coltype == 5: # decimal
                 precision, scale = (collength & 0xFF00) >> 8, collength & 0xFF
                 if scale == 255:
                     scale = 0
-                coltype = InfoNumeric(precision, scale)
+                coltype = sqltypes.Numeric(precision, scale)
             else:
                 try:
                     coltype = ischema_names[coltype]
@@ -286,54 +237,69 @@ class InformixDialect(default.DefaultDialect):
                     util.warn("Did not recognize type '%s' of column '%s'" %
                               (coltype, name))
                     coltype = sqltypes.NULLTYPE
-
-            colargs = []
-            if default is not None:
-                colargs.append(schema.DefaultClause(sql.text(default)))
-
-            table.append_column(schema.Column(name, coltype, nullable = (nullable == 0), *colargs))
-
+            
+            # TODO: nullability ??
+            nullable = True
+            
+            column_info = dict(name=name, type=coltype, nullable=nullable,
+                               default=default)
+            columns.append(column_info)
+        return columns
+
+    @reflection.cache
+    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
         # FK
         c = connection.execute("""select t1.constrname as cons_name , t1.constrtype as cons_type ,
-                                         t4.colname as local_column , t7.tabname as remote_table ,
-                                         t6.colname as remote_column
-                                    from sysconstraints as t1 , systables as t2 ,
-                                         sysindexes as t3 , syscolumns as t4 ,
-                                         sysreferences as t5 , syscolumns as t6 , systables as t7 ,
-                                         sysconstraints as t8 , sysindexes as t9
-                                   where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=? and t1.constrtype = 'R'
-                                     and t3.tabid = t2.tabid and t3.idxname = t1.idxname
-                                     and t4.tabid = t2.tabid and t4.colno = t3.part1
-                                     and t5.constrid = t1.constrid and t8.constrid = t5.primary
-                                     and t6.tabid = t5.ptabid and t6.colno = t9.part1 and t9.idxname = t8.idxname
-                                     and t7.tabid = t5.ptabid""", table.name.lower(), owner)
-        rows = c.fetchall()
-        fks = {}
+                 t4.colname as local_column , t7.tabname as remote_table ,
+                 t6.colname as remote_column
+            from sysconstraints as t1 , systables as t2 ,
+                 sysindexes as t3 , syscolumns as t4 ,
+                 sysreferences as t5 , syscolumns as t6 , systables as t7 ,
+                 sysconstraints as t8 , sysindexes as t9
+           where t1.tabid = t2.tabid and t2.tabname=? and t1.constrtype = 'R'
+             and t3.tabid = t2.tabid and t3.idxname = t1.idxname
+             and t4.tabid = t2.tabid and t4.colno = t3.part1
+             and t5.constrid = t1.constrid and t8.constrid = t5.primary
+             and t6.tabid = t5.ptabid and t6.colno = t9.part1 and t9.idxname = t8.idxname
+             and t7.tabid = t5.ptabid""", table.name.lower())
+
+
+        def fkey_rec():
+            return {
+                 'name' : None,
+                 'constrained_columns' : [],
+                 'referred_schema' : None,
+                 'referred_table' : None,
+                 'referred_columns' : []
+             }
+
+        fkeys = util.defaultdict(fkey_rec)
+
         for cons_name, cons_type, local_column, remote_table, remote_column in rows:
-            try:
-                fk = fks[cons_name]
-            except KeyError:
-                fk = ([], [])
-                fks[cons_name] = fk
-            refspec = ".".join([remote_table, remote_column])
-            schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection)
-            if local_column not in fk[0]:
-                fk[0].append(local_column)
-            if refspec not in fk[1]:
-                fk[1].append(refspec)
-
-        for name, value in fks.iteritems():
-            table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], None, link_to_name=True))
-
-        # PK
-        c = connection.execute("""select t1.constrname as cons_name , t1.constrtype as cons_type ,
-                                         t4.colname as local_column
-                                    from sysconstraints as t1 , systables as t2 ,
-                                         sysindexes as t3 , syscolumns as t4
-                                   where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=? and t1.constrtype = 'P'
-                                     and t3.tabid = t2.tabid and t3.idxname = t1.idxname
-                                     and t4.tabid = t2.tabid and t4.colno = t3.part1""", table.name.lower(), owner)
-        rows = c.fetchall()
-        for cons_name, cons_type, local_column in rows:
-            table.primary_key.add(table.c[local_column])
 
+            rec = fkeys[cons_name]
+            rec['name'] = cons_name
+            local_cols, remote_cols = rec['constrained_columns'], rec['referred_columns']
+
+            if not rec['referred_table']:
+                rec['referred_table'] = remote_table
+
+            local_cols.append(local_column)
+            remote_cols.append(remote_column)
+
+        return fkeys.values()
+
+    @reflection.cache
+    def get_primary_keys(self, connection, table_name, schema=None, **kw):
+        c = connection.execute("""select t4.colname as local_column
+                from sysconstraints as t1 , systables as t2 ,
+                     sysindexes as t3 , syscolumns as t4
+               where t1.tabid = t2.tabid and t2.tabname=? and t1.constrtype = 'P'
+                 and t3.tabid = t2.tabid and t3.idxname = t1.idxname
+                 and t4.tabid = t2.tabid and t4.colno = t3.part1""", table.name.lower())
+        return [r[0] for r in c.fetchall()]
+
+    @reflection.cache
+    def get_indexes(self, connection, table_name, schema, **kw):
+        # TODO
+        return []
\ No newline at end of file
index 60d4ba87cb1ebfb3d21b28b082e7488d9a72f64d..722a0f0f4d6dbffd419aaf8d1a2c17fd2e74c8f4 100644 (file)
@@ -1,49 +1,10 @@
 from sqlalchemy.dialects.informix.base import InformixDialect
 from sqlalchemy.engine import default
 
-# for offset
-
-class informix_cursor(object):
-    def __init__(self, con):
-        self.__cursor = con.cursor()
-        self.rowcount = 0
-
-    def offset(self, n):
-        if n > 0:
-            self.fetchmany(n)
-            self.rowcount = self.__cursor.rowcount - n
-            if self.rowcount < 0:
-                self.rowcount = 0
-        else:
-            self.rowcount = self.__cursor.rowcount
-
-    def execute(self, sql, params):
-        if params is None or len(params) == 0:
-            params = []
-
-        return self.__cursor.execute(sql, params)
-
-    def __getattr__(self, name):
-        if name not in ('offset', '__cursor', 'rowcount', '__del__', 'execute'):
-            return getattr(self.__cursor, name)
-
-
 class InfoExecutionContext(default.DefaultExecutionContext):
-    # cursor.sqlerrd
-    # 0 - estimated number of rows returned
-    # 1 - serial value after insert or ISAM error code
-    # 2 - number of rows processed
-    # 3 - estimated cost
-    # 4 - offset of the error into the SQL statement
-    # 5 - rowid after insert
     def post_exec(self):
-        if getattr(self.compiled, "isinsert", False) and self.inserted_primary_key is None:
-            self._last_inserted_ids = [self.cursor.sqlerrd[1]]
-        elif hasattr(self.compiled, 'offset'):
-            self.cursor.offset(self.compiled.offset)
-
-    def create_cursor(self):
-        return informix_cursor(self.connection.connection)
+        if self.isinsert:
+            self._lastrowid = [self.cursor.sqlerrd[1]]
 
 
 class Informix_informixdb(InformixDialect):
@@ -68,6 +29,12 @@ class Informix_informixdb(InformixDialect):
 
         return ([dsn], opt)
 
+    def _get_server_version_info(self, connection):
+        # http://informixdb.sourceforge.net/manual.html#inspecting-version-numbers
+        vers = connection.dbms_version
+        
+        # TODO: not tested
+        return tuple([int(x) for x in vers.split('.')])
 
     def is_disconnect(self, e):
         if isinstance(e, self.dbapi.OperationalError):