]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
DOH !
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Apr 2006 17:13:42 +0000 (17:13 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Apr 2006 17:13:42 +0000 (17:13 +0000)
lib/sqlalchemy/databases/mssql.py [new file with mode: 0644]

diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py
new file mode 100644 (file)
index 0000000..6671276
--- /dev/null
@@ -0,0 +1,492 @@
+# mssql.py
+
+"""
+notes:
+  supports both pymssql and adodbapi interfaces
+
+  IDENTITY columns are supported by using SA schema.Sequence() objects. In other words:
+         Table('test', mss_engine,
+                Column('id',   Integer, Sequence('blah',100,10), primary_key=True),
+                Column('name', String(20))
+              ).create()
+
+         would yield:
+         CREATE TABLE test (
+           id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY,
+           name VARCHAR(20)
+           )
+  note that the start & increment values for sequences are optional and will default to 1,1
+
+  support for SET IDENTITY_INSERT ON mode (automagic on / off for INSERTs)
+
+  support for auto-fetching of @@IDENTITY on insert
+
+  select.limit implemented as SELECT TOP n
+
+
+Known issues / TODO:
+  no support for more than one IDENTITY column per table
+  no support for table reflection of IDENTITY columns with (seed,increment) values other than (1,1)
+  no support for GUID type columns (yet)
+  pymssql has problems with transaction control that this module attempts to work around
+  pymssql has problems with binary and unicode data that this module does NOT work around
+  adodbapi fails testtypes.py unit test on unicode data too -- issue with the test?
+
+"""
+
+import sys, StringIO, string, types, re, datetime
+
+import sqlalchemy.sql as sql
+import sqlalchemy.engine as engine
+import sqlalchemy.schema as schema
+import sqlalchemy.ansisql as ansisql
+import sqlalchemy.types as sqltypes
+from sqlalchemy import *
+
+try:
+    import adodbapi as dbmodule
+    # ADODBAPI has a non-standard Connection method
+    connect = dbmodule.Connection
+    make_connect_string = lambda keys: \
+        [["Provider=SQLOLEDB;Data Source=%s;User Id=%s;Password=%s;Initial Catalog=%s" % (
+            keys["host"], keys["user"], keys["password"], keys["database"])], {}]
+    do_commit = False
+except:
+    try:
+        import pymssql as dbmodule
+        connect = dbmodule.connect
+        # pymmsql doesn't have a Binary method.  we use string
+        dbmodule.Binary = lambda st: str(st)
+        make_connect_string = lambda keys:  \
+                    [[], keys]
+        do_commit = True
+    except:
+        dbmodule = None
+        make_connect_string = lambda keys: [[],{}]
+    
+class MSNumeric(sqltypes.Numeric):
+    def convert_result_value(self, value, engine):
+        return value
+
+    def convert_bind_param(self, value, engine):
+        if value is None:
+            # Not sure that this exception is needed
+            return value
+        else:
+            return str(value) 
+
+    def get_col_spec(self):
+        return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
+
+class MSFloat(sqltypes.Float):
+    def get_col_spec(self):
+        return "FLOAT(%(precision)s)" % {'precision': self.precision}
+    def convert_bind_param(self, value, engine):
+        """By converting to string, we can use Decimal types round-trip."""
+        return str(value) 
+
+class MSInteger(sqltypes.Integer):
+    def get_col_spec(self):
+        return "INTEGER"
+
+class MSSmallInteger(sqltypes.Smallinteger):
+    def get_col_spec(self):
+        return "SMALLINT"
+
+class MSDateTime(sqltypes.DateTime):
+    def get_col_spec(self):
+        return "DATETIME"
+
+    def convert_bind_param(self, value, engine):
+        if hasattr(value, "isoformat"):
+            return value.isoformat(' ')
+        else:
+            return value
+
+    def convert_result_value(self, value, engine):
+        # adodbapi will return datetimes with empty time values as datetime.date() objects. Promote them back to full datetime.datetime()
+        if value and not hasattr(value, 'second'):
+            return datetime.datetime(value.year, value.month, value.day)
+        return value
+
+class MSDate(sqltypes.Date):
+    def get_col_spec(self):
+        return "SMALLDATETIME"
+    
+    def convert_bind_param(self, value, engine):
+        if value and hasattr(value, "isoformat"):
+            return value.isoformat()
+        return value
+
+    def convert_result_value(self, value, engine):
+        # pymssql will return SMALLDATETIME values as datetime.datetime(), truncate it back to datetime.date()
+        if value and hasattr(value, 'second'):
+            return value.date()
+        return value
+
+class MSText(sqltypes.TEXT):
+    def get_col_spec(self):
+        return "TEXT"
+class MSString(sqltypes.String):
+    def get_col_spec(self):
+        return "VARCHAR(%(length)s)" % {'length' : self.length}
+class MSChar(sqltypes.CHAR):
+    def get_col_spec(self):
+        return "CHAR(%(length)s)" % {'length' : self.length}
+class MSBinary(sqltypes.Binary):
+    def get_col_spec(self):
+        return "IMAGE"
+class MSBoolean(sqltypes.Boolean):
+    def get_col_spec(self):
+        return "BIT"
+        
+colspecs = {
+    sqltypes.Integer : MSInteger,
+    sqltypes.Smallinteger: MSSmallInteger,
+    sqltypes.Numeric : MSNumeric,
+    sqltypes.Float : MSFloat,
+    sqltypes.DateTime : MSDateTime,
+    sqltypes.Date : MSDate,
+    sqltypes.String : MSString,
+    sqltypes.Binary : MSBinary,
+    sqltypes.Boolean : MSBoolean,
+    sqltypes.TEXT : MSText,
+    sqltypes.CHAR: MSChar,
+}
+
+ischema_names = {
+    'int' : MSInteger,
+    'smallint' : MSSmallInteger,
+    'varchar' : MSString,
+    'char' : MSChar,
+    'text' : MSText,
+    'decimal' : MSNumeric,
+    'numeric' : MSNumeric,
+    'float' : MSFloat,
+    'datetime' : MSDateTime,
+    'smalldatetime' : MSDate,
+    'binary' : MSBinary,
+    'bit': MSBoolean,
+    'real' : MSFloat,
+    'image' : MSBinary
+}
+
+def engine(opts, **params):
+    return MSSQLEngine(opts, **params)
+
+def descriptor():
+    return {'name':'mssql',
+    'description':'MSSQL',
+    'arguments':[
+        ('user',"Database Username",None),
+        ('password',"Database Password",None),
+        ('db',"Database Name",None),
+        ('host',"Hostname", None),
+    ]}
+
+class MSSQLEngine(ansisql.ANSISQLEngine):
+    def __init__(self, opts, module = None, **params):
+        if module is None:
+            self.module = dbmodule
+        self.opts = opts or {}
+        ansisql.ANSISQLEngine.__init__(self, **params)
+
+    def connect_args(self):
+        return make_connect_string(self.opts)
+
+    def type_descriptor(self, typeobj):
+        return sqltypes.adapt_type(typeobj, colspecs)
+
+    def last_inserted_ids(self):
+        return self.context.last_inserted_ids
+
+    def supports_sane_rowcount(self):
+        return True
+
+    def compiler(self, statement, bindparams, **kwargs):
+        return MSSQLCompiler(statement, bindparams, engine=self, **kwargs)
+
+    def schemagenerator(self, **params):
+        return MSSQLSchemaGenerator(self, **params)
+
+    def schemadropper(self, **params):
+        return MSSQLSchemaDropper(self, **params)
+
+    def get_default_schema_name(self):
+        return "dbo"
+        
+    def last_inserted_ids(self):
+        return self.context.last_inserted_ids
+            
+    def do_begin(self, connection):
+        """implementations might want to put logic here for turning autocommit on/off, etc."""
+        if do_commit:
+            pass  
+
+    def _execute(self, c, statement, parameters):
+        try:
+            c.execute(statement, parameters)
+            self.context.rowcount = c.rowcount
+            c.DBPROP_COMMITPRESERVE = "Y"
+        except Exception, e:
+            # del c.parent  # Close the Parent Connection, delete it from the pool
+            raise exceptions.SQLError(statement, parameters, e)
+
+
+    def do_rollback(self, connection):
+        """implementations might want to put logic here for turning autocommit on/off, etc."""
+        if do_commit:
+            try:
+                # connection.rollback() for pymmsql failed sometimes--the begin tran doesn't show up
+                # this is a workaround that seems to be handle it.
+                r = self.raw_connection(connection)
+                r.query("if @@trancount > 0 rollback tran")
+                r.fetch_array()
+                r.query("begin tran")
+                r.fetch_array()
+            except:
+                pass
+        try:
+            del connection
+        except:
+            raise
+
+    def raw_connection(self, connection):
+        """Pull the raw pymmsql connection out--sensative to "pool.ConnectionFairy" and pymssql.pymssqlCnx Classes"""
+        try:
+            return connection.connection.__dict__['_pymssqlCnx__cnx']
+        except:
+            return connection.connection.adoConn
+
+    def do_commit(self, connection):
+        """implementations might want to put logic here for turning autocommit on/off, etc.
+            do_commit is set for pymmsql connections--ADO seems to handle transactions without any issue 
+        """
+        # ADO Uses Implicit Transactions.
+        if do_commit:
+            # This is very pymssql specific.  We use this instead of its commit, because it hangs on failed rollbacks.
+            # By using the "if" we don't assume an open transaction--much better.
+            r = self.raw_connection(connection)
+            r.query("if @@trancount > 0 commit tran")
+            r.fetch_array()
+            r.query("begin tran")
+            r.fetch_array()
+        else:
+            pass
+            #connection.supportsTransactions = 1
+            try:
+                pass
+                #connection.adoConn.CommitTrans()
+            except:
+                pass
+                #connection.adoConn.execute("begin trans", {})
+            #connection.adoConn.BeginTrans()
+
+    def connection(self):
+        """returns a managed DBAPI connection from this SQLEngine's connection pool."""
+        c = self._pool.connect()
+        c.supportsTransactions = 0
+        return c
+
+    def pre_exec(self, proxy, compiled, parameters, **kwargs):
+        """ MS-SQL has a special mode for inserting non-NULL values into IDENTITY columns. Activate it if needed. """
+        if getattr(compiled, "isinsert", False):
+            self.context.IINSERT = False
+            self.context.HASIDENT = False
+            for c in compiled.statement.table.c:
+                if hasattr(c,'sequence'):
+                    self.context.HASIDENT = True
+                    if parameters.has_key(c.name):
+                        self.context.IINSERT = True
+                    break
+            if self.context.IINSERT:
+                proxy("SET IDENTITY_INSERT %s ON" % compiled.statement.table.name)
+
+    def post_exec(self, proxy, compiled, parameters, **kwargs):
+        """ Turn off the INDENTITY_INSERT mode if it's been activated, and fetch recently inserted IDENTIFY values (works only for one column) """
+        if getattr(compiled, "isinsert", False):
+            if self.context.IINSERT:
+                proxy("SET IDENTITY_INSERT %s OFF" % compiled.statement.table.name)
+                self.context.IINSERT = False
+            elif self.context.HASIDENT:
+                cursor = proxy("SELECT @@IDENTITY AS lastrowid")
+                row = cursor.fetchone()
+                self.context.last_inserted_ids = [row[0]]
+            self.context.HASIDENT = False
+            
+    def dbapi(self):
+        return self.module
+
+    def reflecttable(self, table):
+        import sqlalchemy.databases.information_schema as ischema
+        
+        # Get base columns
+        if table.schema is not None:
+            current_schema = table.schema
+        else:
+            current_schema = self.get_default_schema_name()
+
+        columns = ischema.gen_columns.toengine(self)
+        s = select([columns],
+                   current_schema and sql.and_(columns.c.table_name==table.name, columns.c.table_schema==current_schema) or columns.c.table_name==table.name,
+                   order_by=[columns.c.ordinal_position])
+        
+        c = s.execute()
+        while True:
+            row = c.fetchone()
+            if row is None:
+                break
+
+            (name, type, nullable, charlen, numericprec, numericscale, default) = (
+                row[columns.c.column_name], 
+                row[columns.c.data_type], 
+                row[columns.c.is_nullable] == 'YES', 
+                row[columns.c.character_maximum_length],
+                row[columns.c.numeric_precision],
+                row[columns.c.numeric_scale],
+                row[columns.c.column_default]
+            )
+
+            args = []
+            for a in (charlen, numericprec, numericscale):
+                if a is not None:
+                    args.append(a)
+                    coltype = ischema_names[type]
+        
+            coltype = coltype(*args)
+            colargs= []
+            if default is not None:
+                colargs.append(PassiveDefault(sql.text(default)))
+                
+            table.append_item(schema.Column(name, coltype, nullable=nullable, *colargs))
+
+
+        # We also run an sp_columns to check for identity columns:
+        # FIXME: note that this only fetches the existence of an identity column, not it's properties like (seed, increment)
+        cursor = table.engine.execute("sp_columns " + table.name, {})
+        while True:
+            row = cursor.fetchone()
+            if row is None:
+                break
+            col_name, type_name = row[3], row[5]
+            if type_name.endswith("identity"):
+                ic = table.c[col_name]
+                # setup a psuedo-sequence to represent the identity attribute - we interpret this at table.create() time as the identity attribute
+                ic.sequence = schema.Sequence(ic.name + '_identity')
+
+        # Add constraints
+        RR = ischema.gen_ref_constraints.toengine(self)    #information_schema.referential_constraints
+        TC = ischema.gen_constraints.toengine(self)        #information_schema.table_constraints
+        C  = ischema.gen_column_constraints.toengine(self).alias('C') #information_schema.constraint_column_usage: the constrained column 
+        R  = ischema.gen_column_constraints.toengine(self).alias('R') #information_schema.constraint_column_usage: the referenced column
+
+        fromjoin = TC.join(RR, RR.c.constraint_name == TC.c.constraint_name).join(C, C.c.constraint_name == RR.c.constraint_name)
+        fromjoin = fromjoin.join(R, R.c.constraint_name == RR.c.unique_constraint_name)
+
+        s = select([TC.c.constraint_type, C.c.table_schema, C.c.table_name, C.c.column_name,
+                    R.c.table_schema, R.c.table_name, R.c.column_name],
+                   and_(RR.c.constraint_schema == current_schema,  C.c.table_name == table.name),
+                   from_obj = [fromjoin]
+                   )
+               
+        c = s.execute()
+
+        while True:
+            row = c.fetchone()
+            if row is None:
+                break
+            (type, constrained_column, referred_schema, referred_table, referred_column) = (
+                row[colmap[0]],
+                row[colmap[3]],
+                row[colmap[4]],
+                row[colmap[5]],
+                row[colmap[6]]
+                )
+
+            if type=='PRIMARY KEY':
+                table.c[constrained_column]._set_primary_key()
+            elif type=='FOREIGN KEY':
+                remotetable = Table(referred_table, self, autoload = True, schema=referred_schema)
+                table.c[constrained_column].append_item(schema.ForeignKey(remotetable.c[referred_column]))
+        
+
+
+class MSSQLCompiler(ansisql.ANSICompiler):
+    def __init__(self, *args, **kwargs):
+        super(MSSQLCompiler, self).__init__(*args, **kwargs)
+        self.tablealiases = {}
+
+    def visit_select_precolumns(self, select):
+        """ MS-SQL puts TOP, it's version of LIMIT here """
+        s = select.distinct and "DISTINCT " or ""
+        if (select.limit):
+            s += "TOP %s " % (select.limit,)
+        return s
+
+    def limit_clause(self, select):
+        # Limit in mssql is after the select keyword; MSsql has no support for offset
+        return ""
+
+            
+    def visit_table(self, table):
+        # alias schema-qualified tables
+         if self.tablealiases.has_key(table):
+            super(MSSQLCompiler, self).visit_table(table)
+         elif getattr(table, 'schema', None) is not None:
+            alias = table.alias()
+            self.tablealiases[table] = alias
+            alias.accept_visitor(self)
+            self.froms[('alias', table)] = self.froms[table]
+            for c in alias.c:
+                c.accept_visitor(self)
+            self.tablealiases[alias] = self.froms[table]
+            self.froms[table] = self.froms[alias]
+
+    def visit_alias(self, alias):
+        # translate for schema-qualified table aliases
+        if self.froms.has_key(('alias', alias.original)):
+            self.froms[alias] = self.froms[('alias', alias.original)] + " AS " + alias.name
+            self.strings[alias] = ""
+        else:
+            super(MSSQLCompiler, self).visit_alias(alias)
+
+    def visit_column(self, column):
+        # translate for schema-qualified table aliases
+        super(MSSQLCompiler, self).visit_column(column)
+        if column.table is not None and self.tablealiases.has_key(column.table):
+            self.strings[column] = \
+                self.strings[self.tablealiases[column.table]._get_col_by_original(column.original)]
+
+        
+class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
+    def get_column_specification(self, column, override_pk=False, first_pk=False):
+        colspec = column.name + " " + column.type.get_col_spec()
+
+        # install a IDENTITY Sequence if we have an implicit IDENTITY column
+        if column.primary_key and isinstance(column.type, types.Integer):
+            if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional):
+                column.sequence = schema.Sequence(column.name + '_seq')
+
+        if not column.nullable:
+            colspec += " NOT NULL"
+
+        if hasattr(column, 'sequence'):
+            colspec += " IDENTITY(%s,%s)" % (column.sequence.start or 1, column.sequence.increment or 1)
+        else:
+            default = self.get_column_default_string(column)
+            if default is not None:
+                colspec += " DEFAULT " + default
+
+        if column.primary_key:
+            if not override_pk:
+                colspec += " PRIMARY KEY"
+        if column.foreign_key:
+            colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.fullname, column.foreign_key.column.name)
+        
+        return colspec
+
+
+class MSSQLSchemaDropper(ansisql.ANSISchemaDropper):
+    def visit_index(self, index):
+        self.append("\nDROP INDEX " + index.table.name + "." + index.name)
+        self.execute()