]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
merged r. morrisons 0.2 update from branch to trunk
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 27 May 2006 00:16:45 +0000 (00:16 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 27 May 2006 00:16:45 +0000 (00:16 +0000)
lib/sqlalchemy/databases/mssql.py

index a8124537ac889e6e1b5d8a366980ab586acf9c39..94fbd622f4386d178f102f76c1e7a7a3cb610b42 100644 (file)
@@ -38,10 +38,11 @@ import sys, StringIO, string, types, re, datetime
 
 import sqlalchemy.sql as sql
 import sqlalchemy.engine as engine
+import sqlalchemy.engine.default as default
 import sqlalchemy.schema as schema
 import sqlalchemy.ansisql as ansisql
 import sqlalchemy.types as sqltypes
-from sqlalchemy import *
+import sqlalchemy.exceptions as exceptions
 
 try:
     import adodbapi as dbmodule
@@ -65,10 +66,10 @@ except:
         make_connect_string = lambda keys: [[],{}]
     
 class MSNumeric(sqltypes.Numeric):
-    def convert_result_value(self, value, engine):
+    def convert_result_value(self, value, dialect):
         return value
 
-    def convert_bind_param(self, value, engine):
+    def convert_bind_param(self, value, dialect):
         if value is None:
             # Not sure that this exception is needed
             return value
@@ -81,7 +82,7 @@ class MSNumeric(sqltypes.Numeric):
 class MSFloat(sqltypes.Float):
     def get_col_spec(self):
         return "FLOAT(%(precision)s)" % {'precision': self.precision}
-    def convert_bind_param(self, value, engine):
+    def convert_bind_param(self, value, dialect):
         """By converting to string, we can use Decimal types round-trip."""
         return str(value) 
 
@@ -97,13 +98,14 @@ class MSDateTime(sqltypes.DateTime):
     def get_col_spec(self):
         return "DATETIME"
 
-    def convert_bind_param(self, value, engine):
+    def convert_bind_param(self, value, dialect):
         if hasattr(value, "isoformat"):
-            return value.isoformat(' ')
+            #return value.isoformat(' ')
+            return value.strftime('%Y-%m-%d %H:%M:%S')            # isoformat() bings on apodbapi -- reported/suggested by Peter Buschman
         else:
             return value
 
-    def convert_result_value(self, value, engine):
+    def convert_result_value(self, value, dialect):
         # adodbapi will return datetimes with empty time values as datetime.date() objects. Promote them back to full datetime.datetime()
         if value and not hasattr(value, 'second'):
             return datetime.datetime(value.year, value.month, value.day)
@@ -113,12 +115,12 @@ class MSDate(sqltypes.Date):
     def get_col_spec(self):
         return "SMALLDATETIME"
     
-    def convert_bind_param(self, value, engine):
+    def convert_bind_param(self, value, dialect):
         if value and hasattr(value, "isoformat"):
             return value.isoformat()
         return value
 
-    def convert_result_value(self, value, engine):
+    def convert_result_value(self, value, dialect):
         # pymssql will return SMALLDATETIME values as datetime.datetime(), truncate it back to datetime.date()
         if value and hasattr(value, 'second'):
             return value.date()
@@ -184,16 +186,49 @@ def descriptor():
         ('host',"Hostname", None),
     ]}
 
-class MSSQLEngine(ansisql.ANSISQLEngine):
-    def __init__(self, opts, module = None, **params):
-        if module is None:
-            self.module = dbmodule
-        self.opts = opts or {}
-        ansisql.ANSISQLEngine.__init__(self, **params)
+class MSSQLExecutionContext(default.DefaultExecutionContext):
+    def pre_exec(self, engine, proxy, compiled, parameters, **kwargs):
+        """ MS-SQL has a special mode for inserting non-NULL values into IDENTITY columns. Activate it if needed. """
+        if getattr(compiled, "isinsert", False):
+            self.IINSERT = False
+            self.HASIDENT = False
+            for c in compiled.statement.table.c:
+                if hasattr(c,'sequence'):
+                    self.HASIDENT = True
+                    if parameters.has_key(c.name):
+                        self.IINSERT = True
+                    break
+            if self.IINSERT:
+                proxy("SET IDENTITY_INSERT %s ON" % compiled.statement.table.name)
+
+    def post_exec(self, engine, proxy, compiled, parameters, **kwargs):
+        """ Turn off the INDENTITY_INSERT mode if it's been activated, and fetch recently inserted IDENTIFY values (works only for one column) """
+        if getattr(compiled, "isinsert", False):
+            if self.IINSERT:
+                proxy("SET IDENTITY_INSERT %s OFF" % compiled.statement.table.name)
+                self.IINSERT = False
+            elif self.HASIDENT:
+                cursor = proxy("SELECT @@IDENTITY AS lastrowid")
+                row = cursor.fetchone()
+                self.last_inserted_ids = [row[0]]
+            self.HASIDENT = False
+
+class MSSQLDialect(ansisql.ANSIDialect):            
+    def __init__(self, module = None, **params):
+        self.module = module or dbmodule
+        self.opts = {}
+        ansisql.ANSIDialect.__init__(self, **params)
+
+    def create_connect_args(self, url):
+        self.opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port'])
+        return ([], self.opts)
 
     def connect_args(self):
         return make_connect_string(self.opts)
 
+    def create_execution_context(self):
+        return MSSQLExecutionContext(self)
+
     def type_descriptor(self, typeobj):
         return sqltypes.adapt_type(typeobj, colspecs)
 
@@ -204,13 +239,16 @@ class MSSQLEngine(ansisql.ANSISQLEngine):
         return True
 
     def compiler(self, statement, bindparams, **kwargs):
-        return MSSQLCompiler(statement, bindparams, engine=self, **kwargs)
+        return MSSQLCompiler(self, statement, bindparams, **kwargs)
 
-    def schemagenerator(self, **params):
-        return MSSQLSchemaGenerator(self, **params)
+    def schemagenerator(self, *args, **kwargs):
+        return MSSQLSchemaGenerator(*args, **kwargs)
 
-    def schemadropper(self, **params):
-        return MSSQLSchemaDropper(self, **params)
+    def schemadropper(self, *args, **kwargs):
+        return MSSQLSchemaDropper(*args, **kwargs)
+
+    def defaultrunner(self, engine, proxy):
+        return MSSQLDefaultRunner(engine, proxy)
 
     def get_default_schema_name(self):
         return "dbo"
@@ -229,9 +267,9 @@ class MSSQLEngine(ansisql.ANSISQLEngine):
             self.context.rowcount = c.rowcount
             c.DBPROP_COMMITPRESERVE = "Y"
         except Exception, e:
-            # del c.parent  # Close the Parent Connection, delete it from the pool
-            raise exceptions.SQLError(statement, parameters, e)
+            # del c.parent  # Close the Parent Connection, delete it from the pool        columns = ischema.columns.toengine(self)
 
+            raise exceptions.SQLError(statement, parameters, e)
 
     def do_rollback(self, connection):
         """implementations might want to put logic here for turning autocommit on/off, etc."""
@@ -288,36 +326,11 @@ class MSSQLEngine(ansisql.ANSISQLEngine):
         c.supportsTransactions = 0
         return c
 
-    def pre_exec(self, proxy, compiled, parameters, **kwargs):
-        """ MS-SQL has a special mode for inserting non-NULL values into IDENTITY columns. Activate it if needed. """
-        if getattr(compiled, "isinsert", False):
-            self.context.IINSERT = False
-            self.context.HASIDENT = False
-            for c in compiled.statement.table.c:
-                if hasattr(c,'sequence'):
-                    self.context.HASIDENT = True
-                    if parameters.has_key(c.name):
-                        self.context.IINSERT = True
-                    break
-            if self.context.IINSERT:
-                proxy("SET IDENTITY_INSERT %s ON" % compiled.statement.table.name)
-
-    def post_exec(self, proxy, compiled, parameters, **kwargs):
-        """ Turn off the INDENTITY_INSERT mode if it's been activated, and fetch recently inserted IDENTIFY values (works only for one column) """
-        if getattr(compiled, "isinsert", False):
-            if self.context.IINSERT:
-                proxy("SET IDENTITY_INSERT %s OFF" % compiled.statement.table.name)
-                self.context.IINSERT = False
-            elif self.context.HASIDENT:
-                cursor = proxy("SELECT @@IDENTITY AS lastrowid")
-                row = cursor.fetchone()
-                self.context.last_inserted_ids = [row[0]]
-            self.context.HASIDENT = False
-            
+          
     def dbapi(self):
         return self.module
 
-    def reflecttable(self, table):
+    def reflecttable(self, connection, table):
         import sqlalchemy.databases.information_schema as ischema
         
         # Get base columns
@@ -326,12 +339,12 @@ class MSSQLEngine(ansisql.ANSISQLEngine):
         else:
             current_schema = self.get_default_schema_name()
 
-        columns = ischema.gen_columns.toengine(self)
+        columns = ischema.columns.toengine(self)
         s = select([columns],
                    current_schema and sql.and_(columns.c.table_name==table.name, columns.c.table_schema==current_schema) or columns.c.table_name==table.name,
                    order_by=[columns.c.ordinal_position])
         
-        c = s.execute()
+        c = connection.execute(s)
         while True:
             row = c.fetchone()
             if row is None:
@@ -352,7 +365,6 @@ class MSSQLEngine(ansisql.ANSISQLEngine):
                 if a is not None:
                     args.append(a)
                     coltype = ischema_names[type]
-        
             coltype = coltype(*args)
             colargs= []
             if default is not None:
@@ -363,7 +375,9 @@ class MSSQLEngine(ansisql.ANSISQLEngine):
 
         # We also run an sp_columns to check for identity columns:
         # FIXME: note that this only fetches the existence of an identity column, not it's properties like (seed, increment)
-        cursor = table.engine.execute("sp_columns " + table.name, {})
+        #        also, add a check to make sure we specify the schema name of the table
+        # cursor = table.engine.execute("sp_columns " + table.name, {})
+        cursor = connection.execute("sp_columns " + table.name)
         while True:
             row = cursor.fetchone()
             if row is None:
@@ -375,10 +389,10 @@ class MSSQLEngine(ansisql.ANSISQLEngine):
                 ic.sequence = schema.Sequence(ic.name + '_identity')
 
         # Add constraints
-        RR = ischema.gen_ref_constraints.toengine(self)    #information_schema.referential_constraints
-        TC = ischema.gen_constraints.toengine(self)        #information_schema.table_constraints
-        C  = ischema.gen_column_constraints.toengine(self).alias('C') #information_schema.constraint_column_usage: the constrained column 
-        R  = ischema.gen_column_constraints.toengine(self).alias('R') #information_schema.constraint_column_usage: the referenced column
+        RR = ischema.ref_constraints.toengine(self)    #information_schema.referential_constraints
+        TC = ischema.constraints.toengine(self)        #information_schema.table_constraints
+        C  = ischema.column_constraints.toengine(self).alias('C') #information_schema.constraint_column_usage: the constrained column 
+        R  = ischema.column_constraints.toengine(self).alias('R') #information_schema.constraint_column_usage: the referenced column
 
         fromjoin = TC.join(RR, RR.c.constraint_name == TC.c.constraint_name).join(C, C.c.constraint_name == RR.c.constraint_name)
         fromjoin = fromjoin.join(R, R.c.constraint_name == RR.c.unique_constraint_name)
@@ -389,7 +403,7 @@ class MSSQLEngine(ansisql.ANSISQLEngine):
                    from_obj = [fromjoin]
                    )
                
-        c = s.execute()
+        c = connection.execute(s)
 
         while True:
             row = c.fetchone()
@@ -412,8 +426,8 @@ class MSSQLEngine(ansisql.ANSISQLEngine):
 
 
 class MSSQLCompiler(ansisql.ANSICompiler):
-    def __init__(self, *args, **kwargs):
-        super(MSSQLCompiler, self).__init__(*args, **kwargs)
+    def __init__(self, dialect, statement, parameters, **kwargs):
+        super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs)
         self.tablealiases = {}
 
     def visit_select_precolumns(self, select):
@@ -463,7 +477,7 @@ class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
         colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec()
 
         # install a IDENTITY Sequence if we have an implicit IDENTITY column
-        if column.primary_key and isinstance(column.type, types.Integer):
+        if column.primary_key and isinstance(column.type, sqltypes.Integer):
             if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional):
                 column.sequence = schema.Sequence(column.name + '_seq')
 
@@ -490,3 +504,8 @@ class MSSQLSchemaDropper(ansisql.ANSISchemaDropper):
     def visit_index(self, index):
         self.append("\nDROP INDEX " + index.table.name + "." + index.name)
         self.execute()
+
+class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner):
+    pass
+
+dialect = MSSQLDialect