]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Added reflection to sqlalchemy.dialects.sybase
authorBenjamin Trofatter <bentrofatter@gmail.com>
Tue, 30 Oct 2012 08:52:22 +0000 (03:52 -0500)
committerBenjamin Trofatter <bentrofatter@gmail.com>
Tue, 30 Oct 2012 08:52:22 +0000 (03:52 -0500)
Added missing types supported by Sybase to ischema_names mapping

Created a SybaseInspector similar to the PGInspector, with a cached table_id
  lookup, and added it to the SybaseDialect as the default inspector.

Added the following methods to SybaseDialect:
  get_table_id
  get_columns
  _get_column_info : support method for get_columns
  get_foreign_keys
  get_indexes
  get_pk_constraint
  get_schema_names
  get_view_definition
  get_view_names
Rewrote the following methods to conform to the style of the rest:
  get_table_names
  has_table

Reordered colspec builder to put default clause before "NULL/NOT NULL",
  instead of after.  This fixed a syntax error.

lib/sqlalchemy/dialects/sybase/base.py

index 2d213ed5b2d6a6eabb1cbb4b592466a5f87327cb..e62d37447e50688d4a59ca8f51d03cb80946ef08 100644 (file)
@@ -21,8 +21,9 @@
     and database reflection features are not implemented.
 
 """
-
 import operator
+import re
+
 from sqlalchemy.sql import compiler, expression, text, bindparam
 from sqlalchemy.engine import default, base, reflection
 from sqlalchemy import types as sqltypes
@@ -31,10 +32,10 @@ from sqlalchemy import schema as sa_schema
 from sqlalchemy import util, sql, exc
 
 from sqlalchemy.types import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\
-                            TEXT,DATE,DATETIME, FLOAT, NUMERIC,\
-                            BIGINT,INT, INTEGER, SMALLINT, BINARY,\
+                            TEXT, DATE, DATETIME, FLOAT, NUMERIC,\
+                            BIGINT, INT, INTEGER, SMALLINT, BINARY,\
                             VARBINARY, DECIMAL, TIMESTAMP, Unicode,\
-                            UnicodeText
+                            UnicodeText, REAL
 
 RESERVED_WORDS = set([
     "add", "all", "alter", "and",
@@ -173,32 +174,68 @@ class SybaseTypeCompiler(compiler.GenericTypeCompiler):
         return "UNIQUEIDENTIFIER"
 
 ischema_names = {
+    # Maps type names as reported by the Sybase system catalogs to
+    # SQLAlchemy type classes.  Entries are grouped roughly per the
+    # ASE 15.7 type documentation: exact numerics, approximate numerics,
+    # money, date/time, character, unicode, binary, bit.
-    'integer' : INTEGER,
-    'unsigned int' : INTEGER, # TODO: unsigned flags
-    'unsigned smallint' : SMALLINT, # TODO: unsigned flags
-    'unsigned bigint' : BIGINT, # TODO: unsigned flags
     'bigint': BIGINT,
+    'int' : INTEGER,
+    'integer' : INTEGER,
     'smallint' : SMALLINT,
     'tinyint' : TINYINT,
-    'varchar' : VARCHAR,
-    'long varchar' : TEXT, # TODO
-    'char' : CHAR,
-    'decimal' : DECIMAL,
+    'unsigned bigint' : BIGINT, # TODO: unsigned flags
+    'unsigned int' : INTEGER, # TODO: unsigned flags
+    'unsigned smallint' : SMALLINT, # TODO: unsigned flags
     'numeric' : NUMERIC,
+    'decimal' : DECIMAL,
+    'dec' : DECIMAL,
     'float' : FLOAT,
     'double' : NUMERIC, # TODO
+    'double precision' : NUMERIC, # TODO
+    'real': REAL,
+    'smallmoney': SMALLMONEY,
+    'money': MONEY,
+    'smalldatetime': DATETIME,
+    'datetime': DATETIME,
+    'date': DATE,
+    'time': TIME,
+    'char' : CHAR,
+    'character' : CHAR,
+    'varchar' : VARCHAR,
+    'character varying' : VARCHAR,
+    'char varying' : VARCHAR,
+    'unichar' : UNICHAR,
+    # NOTE(review): 'univarchar' itself has no entry here — confirm
+    # whether ASE ever reports that name from systypes.
+    'unicode character' : UNIVARCHAR,
+    'nchar': NCHAR,
+    'national char': NCHAR,
+    'national character': NCHAR,
+    'nvarchar': NVARCHAR,
+    'nchar varying': NVARCHAR,
+    'national char varying': NVARCHAR,
+    'national character varying': NVARCHAR,
+    'text': TEXT,
+    'unitext': UNITEXT,
     'binary' : BINARY,
     'varbinary' : VARBINARY,
-    'bit': BIT,
     'image' : IMAGE,
+    'bit': BIT,
+
+    # not in documentation for ASE 15.7
+    'long varchar' : TEXT, # TODO
     'timestamp': TIMESTAMP,
-    'money': MONEY,
-    'smallmoney': MONEY,
     'uniqueidentifier': UNIQUEIDENTIFIER,
 
 }
 
 
+class SybaseInspector(reflection.Inspector):
+    """Inspector for Sybase databases.
+
+    Adds a ``get_table_id`` accessor, since the dialect-level
+    reflection queries are keyed by the table's numeric id rather
+    than its name.
+    """
+
+    def __init__(self, conn):
+        reflection.Inspector.__init__(self, conn)
+
+    def get_table_id(self, table_name, schema=None):
+        """Return the table id from `table_name` and `schema`."""
+
+        # Passing info_cache lets @reflection.cache memoize the lookup
+        # across the several reflection calls made for a single table.
+        return self.dialect.get_table_id(self.bind, table_name, schema,
+                                         info_cache=self.info_cache)
+
+
 class SybaseExecutionContext(default.DefaultExecutionContext):
     _enable_identity_insert = False
 
@@ -246,7 +283,6 @@ class SybaseExecutionContext(default.DefaultExecutionContext):
                         self.root_connection.connection.connection,
                         True)
 
-
     def post_exec(self):
         if self.isddl:
             self.set_ddl_autocommit(self.root_connection, False)
@@ -348,16 +384,16 @@ class SybaseDDLCompiler(compiler.DDLCompiler):
                 # TODO: need correct syntax for this
                 colspec += " IDENTITY(%s,%s)" % (start, increment)
         else:
+            default = self.get_column_default_string(column)
+            if default is not None:
+                colspec += " DEFAULT " + default
+
             if column.nullable is not None:
                 if not column.nullable or column.primary_key:
                     colspec += " NOT NULL"
                 else:
                     colspec += " NULL"
 
-            default = self.get_column_default_string(column)
-            if default is not None:
-                colspec += " DEFAULT " + default
-
         return colspec
 
     def visit_drop_index(self, drop):
@@ -388,6 +424,7 @@ class SybaseDialect(default.DefaultDialect):
     statement_compiler = SybaseSQLCompiler
     ddl_compiler = SybaseDDLCompiler
     preparer = SybaseIdentifierPreparer
+    inspector = SybaseInspector
 
     def _get_default_schema_name(self, connection):
         return connection.scalar(
@@ -403,39 +440,381 @@ class SybaseDialect(default.DefaultDialect):
         else:
             self.max_identifier_length = 255
 
+    @reflection.cache
+    def get_table_id(self, connection, table_name, schema=None, **kw):
+        """Fetch the id for schema.table_name.
+
+        Several reflection methods require the table id.  The idea for using
+        this method is that it can be fetched one time and cached for
+        subsequent calls.
+
+        """
+
+        # NOTE(review): this initializer is redundant; table_id is always
+        # assigned from the query result below.
+        table_id = None
+        if schema is None:
+            schema = self.default_schema_name
+
+        # Restrict to user tables (o.type = 'U'), matching get_table_names.
+        TABLEID_SQL = text("""
+          SELECT o.id AS id
+          FROM sysobjects o JOIN sysusers u ON o.uid=u.uid
+          WHERE u.name = :schema_name
+              AND o.name = :table_name
+              AND o.type = 'U'
+        """)
+
+        # The "Py2K" markers below presumably delimit code stripped for
+        # the Python 3 build of the library — leave them intact.
+        # Py2K
+        if isinstance(schema, unicode):
+            schema = schema.encode("ascii")
+        if isinstance(table_name, unicode):
+            table_name = table_name.encode("ascii")
+        # end Py2K
+        result = connection.execute(TABLEID_SQL,
+                                    schema_name=schema,
+                                    table_name=table_name)
+        table_id = result.scalar()
+        if table_id is None:
+            raise exc.NoSuchTableError(table_name)
+        return table_id
+
+    @reflection.cache
+    def get_columns(self, connection, table_name, schema=None, **kw):
+        """Return column reflection dicts for ``schema.table_name``."""
+        table_id = self.get_table_id(connection, table_name, schema,
+                                     info_cache=kw.get("info_cache"))
+
+        # syscolumns.status is a bitmask: bit 8 presumably marks
+        # NULL-allowed columns and bit 128 identity columns — confirm
+        # against the ASE syscolumns documentation.  The default text,
+        # if any, lives in syscomments, hence the outer join.
+        COLUMN_SQL = text("""
+          SELECT col.name AS name,
+                 t.name AS type, 
+                 (col.status & 8) AS nullable,
+                 (col.status & 128) AS autoincrement,
+                 com.text AS 'default',
+                 col.prec AS precision,
+                 col.scale AS scale,
+                 col.length AS length
+          FROM systypes t, syscolumns col LEFT OUTER JOIN syscomments com ON 
+              col.cdefault = com.id
+          WHERE col.usertype = t.usertype 
+              AND col.id = :table_id
+          ORDER BY col.colid
+        """)
+
+        results = connection.execute(COLUMN_SQL, table_id=table_id)
+
+        columns = []
+        for (name, type_, nullable, autoincrement, default, precision, scale, 
+             length) in results:
+            col_info = self._get_column_info(name, type_, bool(nullable),
+                             bool(autoincrement), default, precision, scale,
+                             length)
+            columns.append(col_info)
+
+        return columns
+
+    def _get_column_info(self, name, type_, nullable, autoincrement, default,
+            precision, scale, length):
+        """Build the reflection dict for one column; helper for get_columns."""
+
+        coltype = self.ischema_names.get(type_, None)
+
+        kwargs = {}
+
+        # Constructor args depend on the type category: numerics take
+        # (precision, scale), floats precision only, character types a
+        # length; everything else takes no arguments.
+        if coltype in (NUMERIC, DECIMAL):
+            args = (precision, scale)
+        elif coltype == FLOAT:
+            args = (precision,)
+        elif coltype in (CHAR, VARCHAR, UNICHAR, UNIVARCHAR, NCHAR, NVARCHAR):
+            args = (length,)
+        else:
+            args = ()
+
+        if coltype:
+            coltype = coltype(*args, **kwargs)
+            #is this necessary
+            #if is_array:
+            #     coltype = ARRAY(coltype)
+        else:
+            # Unknown catalog type name: warn and fall back to NULLTYPE
+            # so that reflection of the rest of the table still succeeds.
+            util.warn("Did not recognize type '%s' of column '%s'" %
+                      (type_, name))
+            coltype = sqltypes.NULLTYPE
+
+        if default:
+            # NOTE(review): this removes *every* occurrence of "DEFAULT"
+            # in the syscomments text, not just a leading keyword, then
+            # strips one layer of surrounding quotes; a default literal
+            # containing the word DEFAULT would be mangled — consider
+            # anchoring the pattern and verify against syscomments output.
+            default = re.sub("DEFAULT", "", default).strip()
+            default = re.sub("^'(.*)'$", lambda m: m.group(1), default)
+        else:
+            default = None
+        column_info = dict(name=name, type=coltype, nullable=nullable,
+                           default=default, autoincrement=autoincrement)
+        return column_info
+
+    @reflection.cache
+    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+        table_id = self.get_table_id(connection, table_name, schema,
+                                     info_cache=kw.get("info_cache"))
+    
+        table_cache = {}
+        column_cache = {}
+        foreign_keys = []
+    
+        table_cache[table_id] = table_name
+    
+        COLUMN_SQL = text("""
+          SELECT c.colid AS id, c.name AS name
+          FROM syscolumns c
+          WHERE c.id = :table_id
+        """)
+    
+        results = connection.execute(COLUMN_SQL, table_id=table_id)
+        columns = {}
+        for col in results:
+            columns[col["id"]] = col["name"]
+        column_cache[table_id] = columns
+    
+        REFCONSTRAINT_SQL = text("""
+          SELECT o.name AS name, r.reftabid AS reftable_id,
+            r.keycnt AS 'count',
+            r.fokey1 AS fokey1, r.fokey2 AS fokey2, r.fokey3 AS fokey3,
+            r.fokey4 AS fokey4, r.fokey5 AS fokey5, r.fokey6 AS fokey6,
+            r.fokey7 AS fokey7, r.fokey1 AS fokey8, r.fokey9 AS fokey9,
+            r.fokey10 AS fokey10, r.fokey11 AS fokey11, r.fokey12 AS fokey12,
+            r.fokey13 AS fokey13, r.fokey14 AS fokey14, r.fokey15 AS fokey15,
+            r.fokey16 AS fokey16,
+            r.refkey1 AS refkey1, r.refkey2 AS refkey2, r.refkey3 AS refkey3,
+            r.refkey4 AS refkey4, r.refkey5 AS refkey5, r.refkey6 AS refkey6,
+            r.refkey7 AS refkey7, r.refkey1 AS refkey8, r.refkey9 AS refkey9,
+            r.refkey10 AS refkey10, r.refkey11 AS refkey11,
+            r.refkey12 AS refkey12, r.refkey13 AS refkey13,
+            r.refkey14 AS refkey14, r.refkey15 AS refkey15,
+            r.refkey16 AS refkey16
+          FROM sysreferences r JOIN sysobjects o on r.tableid = o.id
+          WHERE r.tableid = :table_id
+        """)
+        referential_constraints = connection.execute(REFCONSTRAINT_SQL,
+                                                     table_id=table_id)
+    
+        REFTABLE_SQL = text("""
+          SELECT o.id AS id, o.name AS name, u.name AS 'schema'
+          FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
+          WHERE o.id = :table_id
+        """)
+
+        for r in referential_constraints:
+    
+            reftable_id = r["reftable_id"]
+    
+            if reftable_id not in table_cache:
+                c = connection.execute(REFTABLE_SQL, table_id=reftable_id)
+                reftable = c.fetchone()
+                c.close()
+                table_cache[reftable_id] = {"name": reftable["name"],
+                                            "schema": reftable["schema"]}
+    
+                results = connection.execute(COLUMN_SQL, table_id=reftable_id)
+                reftable_columns = {}
+                for col in results:
+                    reftable_columns[col["id"]] = col["name"]
+                column_cache[reftable_id] = reftable_columns
+    
+            reftable = table_cache[reftable_id]
+            reftable_columns = column_cache[reftable_id]
+    
+            constrained_columns = []
+            referred_columns = []
+            for i in range(1, r["count"]+1):
+                constrained_columns.append(columns[r["fokey%i" % i]])
+                referred_columns.append(reftable_columns[r["refkey%i" % i]])
+    
+            fk_info = {
+                    "constrained_columns": constrained_columns,
+                    "referred_schema": reftable["schema"],
+                    "referred_table": reftable["name"],
+                    "referred_columns": referred_columns,
+                    "name": r["name"]
+                }
+    
+            foreign_keys.append(fk_info)
+    
+        return foreign_keys
+
+    @reflection.cache
+    def get_indexes(self, connection, table_name, schema=None, **kw):
+        """Return index reflection dicts for ``schema.table_name``."""
+        table_id = self.get_table_id(connection, table_name, schema,
+                                     info_cache=kw.get("info_cache"))
+
+        # index_col() resolves the i-th key column name; 16 fixed slots
+        # are selected.  Rows with (i.status & 2048) set are excluded —
+        # presumably primary-key indexes, which get_pk_constraint reports
+        # instead; confirm the status bits against ASE sysindexes docs.
+        INDEX_SQL = text("""
+          SELECT object_name(i.id) AS table_name,
+                 i.keycnt AS 'count',
+                 i.name AS name,
+                 (i.status & 0x2) AS 'unique',
+                 index_col(object_name(i.id), i.indid, 1) AS col_1,
+                 index_col(object_name(i.id), i.indid, 2) AS col_2,
+                 index_col(object_name(i.id), i.indid, 3) AS col_3,
+                 index_col(object_name(i.id), i.indid, 4) AS col_4,
+                 index_col(object_name(i.id), i.indid, 5) AS col_5,
+                 index_col(object_name(i.id), i.indid, 6) AS col_6,
+                 index_col(object_name(i.id), i.indid, 7) AS col_7,
+                 index_col(object_name(i.id), i.indid, 8) AS col_8,
+                 index_col(object_name(i.id), i.indid, 9) AS col_9,
+                 index_col(object_name(i.id), i.indid, 10) AS col_10,
+                 index_col(object_name(i.id), i.indid, 11) AS col_11,
+                 index_col(object_name(i.id), i.indid, 12) AS col_12,
+                 index_col(object_name(i.id), i.indid, 13) AS col_13,
+                 index_col(object_name(i.id), i.indid, 14) AS col_14,
+                 index_col(object_name(i.id), i.indid, 15) AS col_15,
+                 index_col(object_name(i.id), i.indid, 16) AS col_16
+          FROM sysindexes i, sysobjects o
+          WHERE o.id = i.id
+            AND o.id = :table_id
+            AND (i.status & 2048) = 0
+            AND i.indid BETWEEN 1 AND 254
+            AND o.type = 'U'
+        """)
+
+        results = connection.execute(INDEX_SQL, table_id=table_id)
+        indexes = []
+        for r in results:
+            column_names = []
+            # NOTE(review): get_pk_constraint iterates range(1, count+1)
+            # while this uses range(1, count).  ASE's keycnt may include
+            # an extra slot for nonclustered indexes, making this correct
+            # for them but short by one for clustered indexes — verify
+            # against the sysindexes documentation before changing.
+            for i in range(1, r["count"]):
+                column_names.append(r["col_%i" % (i,)])
+            index_info = {"name": r["name"],
+                          "unique": bool(r["unique"]),
+                          "column_names": column_names}
+            indexes.append(index_info)
+
+        return indexes
+
+    @reflection.cache
+    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+        table_id = self.get_table_id(connection, table_name, schema,
+                                     info_cache=kw.get("info_cache"))
+
+        PK_SQL = text("""
+          SELECT object_name(i.id) AS table_name,
+                 i.keycnt AS 'count',
+                 i.name AS name,
+                 index_col(object_name(i.id), i.indid, 1) AS pk_1,
+                 index_col(object_name(i.id), i.indid, 2) AS pk_2,
+                 index_col(object_name(i.id), i.indid, 3) AS pk_3,
+                 index_col(object_name(i.id), i.indid, 4) AS pk_4,
+                 index_col(object_name(i.id), i.indid, 5) AS pk_5,
+                 index_col(object_name(i.id), i.indid, 6) AS pk_6,
+                 index_col(object_name(i.id), i.indid, 7) AS pk_7,
+                 index_col(object_name(i.id), i.indid, 8) AS pk_8,
+                 index_col(object_name(i.id), i.indid, 9) AS pk_9,
+                 index_col(object_name(i.id), i.indid, 10) AS pk_10,
+                 index_col(object_name(i.id), i.indid, 11) AS pk_11,
+                 index_col(object_name(i.id), i.indid, 12) AS pk_12,
+                 index_col(object_name(i.id), i.indid, 13) AS pk_13,
+                 index_col(object_name(i.id), i.indid, 14) AS pk_14,
+                 index_col(object_name(i.id), i.indid, 15) AS pk_15,
+                 index_col(object_name(i.id), i.indid, 16) AS pk_16
+          FROM sysindexes i, sysobjects o
+          WHERE o.id = i.id
+            AND o.id = :table_id
+            AND (i.status & 2048) = 2048
+            AND i.indid BETWEEN 1 AND 254
+            AND o.type = 'U'
+        """)
+
+        results = connection.execute(PK_SQL, table_id=table_id)
+        pks = results.fetchone()
+        results.close()
+        
+        constrained_columns = []
+        for i in range(1, pks["count"]+1):
+            constrained_columns.append(pks["pk_%i" % (i,)])
+        return {"constrained_columns": constrained_columns,
+                "name": pks["name"]}
+
+    @reflection.cache
+    def get_schema_names(self, connection, **kw):
+        """Return all schema names; each sysusers row acts as a schema."""
+
+        SCHEMA_SQL = text("SELECT u.name AS name FROM sysusers u")
+
+        schemas = connection.execute(SCHEMA_SQL)
+
+        return [s["name"] for s in schemas]
+
     @reflection.cache
     def get_table_names(self, connection, schema=None, **kw):
+        """Return user table names for ``schema`` (default schema if None)."""
         if schema is None:
             schema = self.default_schema_name
 
-        result = connection.execute(
-                    text("select sysobjects.name from sysobjects, sysusers "
-                         "where sysobjects.uid=sysusers.uid and "
-                         "sysusers.name=:schemaname and "
-                         "sysobjects.type='U'",
-                         bindparams=[
-                                  bindparam('schemaname', schema)
-                                  ])
-         )
-        return [r[0] for r in result]
-
-    def has_table(self, connection, tablename, schema=None):
+        # o.type = 'U' restricts to user tables; views are reported
+        # separately by get_view_names.
+        TABLE_SQL = text("""
+          SELECT o.name AS name
+          FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
+          WHERE u.name = :schema_name
+            AND o.type = 'U'
+        """)
+
+        # Py2K
+        if isinstance(schema, unicode):
+            schema = schema.encode("ascii")
+        # end Py2K
+        tables = connection.execute(TABLE_SQL, schema_name=schema)
+
+        return [t["name"] for t in tables]
+
+    @reflection.cache
+    def get_view_definition(self, connection, view_name, schema=None, **kw):
+        if schema is None:
+            schema = self.default_schema_name
+
+        VIEW_DEF_SQL = text("""
+          SELECT c.text
+          FROM syscomments c JOIN sysobjects o ON c.id = o.id
+          WHERE o.name = :view_name
+            AND o.type = 'V'
+        """)
+
+        # Py2K
+        if isinstance(view_name, unicode):
+            view_name = view_name.encode("ascii")
+        # end Py2K
+        view = connection.execute(VIEW_DEF_SQL, view_name=view_name)
+
+        return view.scalar()
+
+    @reflection.cache
+    def get_view_names(self, connection, schema=None, **kw):
+        """Return view names for ``schema`` (default schema if None)."""
+        if schema is None:
+            schema = self.default_schema_name
+
+        # Same query shape as get_table_names, but for views ('V').
+        VIEW_SQL = text("""
+          SELECT o.name AS name
+          FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
+          WHERE u.name = :schema_name
+            AND o.type = 'V'
+        """)
+
+        # Py2K
+        if isinstance(schema, unicode):
+            schema = schema.encode("ascii")
+        # end Py2K
+        views = connection.execute(VIEW_SQL, schema_name=schema)
+
+        return [v["name"] for v in views]
+
+    def has_table(self, connection, table_name, schema=None):
+        """Return True if ``schema.table_name`` exists as a user table."""
         if schema is None:
             schema = self.default_schema_name
 
-        result = connection.execute(
-                    text("select sysobjects.name from sysobjects, sysusers "
-                         "where sysobjects.uid=sysusers.uid and "
-                         "sysobjects.name=:tablename and "
-                         "sysusers.name=:schemaname and "
-                         "sysobjects.type='U'",
-                         bindparams=[
-                                  bindparam('tablename', tablename),
-                                  bindparam('schemaname', schema)
-                                  ])
-                 )
+        # Same lookup shape as get_table_id, but returning only
+        # existence rather than the id.
+        HAS_TABLE_SQL = text("""
+          SELECT o.name
+          FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
+          WHERE o.name = :table_name
+             AND u.name = :schema_name
+             AND o.type = 'U'
+        """)
+
+        # Py2K
+        if isinstance(schema, unicode):
+            schema = schema.encode("ascii")
+        if isinstance(table_name, unicode):
+            table_name = table_name.encode("ascii")
+        # end Py2K
+        result = connection.execute(HAS_TABLE_SQL, table_name=table_name,
+                                    schema_name=schema)
         return result.scalar() is not None
 
-    def reflecttable(self, connection, table, include_columns):
-        raise NotImplementedError()
+    #def reflecttable(self, connection, table, include_columns):
+    #    raise NotImplementedError()