]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixes use of port for pymssql
authorRick Morrison <rickmorrison@gmail.com>
Sat, 23 Sep 2006 21:42:34 +0000 (21:42 +0000)
committerRick Morrison <rickmorrison@gmail.com>
Sat, 23 Sep 2006 21:42:34 +0000 (21:42 +0000)
Introduces new auto_indentity_insert option
Fixes bug #261

lib/sqlalchemy/databases/mssql.py

index 6c4fa816c80f0c2a4f34df1bf5939fd88f09c8ee..736351a58ef544ecf770f38a85f8b3638a21ba6d 100644 (file)
@@ -59,8 +59,12 @@ except:
         connect = dbmodule.connect
         # pymmsql doesn't have a Binary method.  we use string
         dbmodule.Binary = lambda st: str(st)
-        make_connect_string = lambda keys:  \
-                    [[], keys]
+        def make_connect_string(keys):
+            if keys.get('port'):
+                # pymssql expects port as host:port, not a separate arg
+                keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])])
+                del keys['port'] 
+            return [[], keys]
         do_commit = True
     except:
         dbmodule = None
@@ -207,19 +211,23 @@ def descriptor():
     ]}
 
 class MSSQLExecutionContext(default.DefaultExecutionContext):
+    def __init__(self, dialect):
+        self.IINSERT = self.HASIDENT = False
+       super(MSSQLExecutionContext, self).__init__(dialect)
+    
     def pre_exec(self, engine, proxy, compiled, parameters, **kwargs):
-        """ MS-SQL has a special mode for inserting non-NULL values into IDENTITY columns. Activate it if needed. """
+        """ MS-SQL has a special mode for inserting non-NULL values into IDENTITY columns. Activate it if the feature is turned on and needed. """
         if getattr(compiled, "isinsert", False):
             self.IINSERT = False
             self.HASIDENT = False
             for c in compiled.statement.table.c:
                 if hasattr(c,'sequence'):
                     self.HASIDENT = True
-                    if isinstance(parameters, list):
+                    if engine.dialect.auto_identity_insert and isinstance(parameters, list):
                         if parameters[0].has_key(c.name):
                             self.IINSERT = True
-                    elif parameters.has_key(c.name):
-                        self.IINSERT = True
+                        elif parameters.has_key(c.name):
+                            self.IINSERT = True
                     break
             if self.IINSERT:
                 proxy("SET IDENTITY_INSERT %s ON" % compiled.statement.table.name)
@@ -235,12 +243,14 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
                 cursor = proxy("SELECT @@IDENTITY AS lastrowid")
                 row = cursor.fetchone()
                 self._last_inserted_ids = [int(row[0])]
-                print "LAST ROW ID", self._last_inserted_ids
+                print "LAST ROW ID", self._last_inserted_ids
             self.HASIDENT = False
 
+
 class MSSQLDialect(ansisql.ANSIDialect):            
-    def __init__(self, module = None, **params):
+    def __init__(self, module=None, auto_identity_insert=False, **params):
         self.module = module or dbmodule
+        self.auto_identity_insert = auto_identity_insert
         ansisql.ANSIDialect.__init__(self, **params)
 
     def create_connect_args(self, url):
@@ -294,6 +304,8 @@ class MSSQLDialect(ansisql.ANSIDialect):
         except Exception, e:
             raise exceptions.SQLError(statement, parameters, e)
 
+
+
     def do_rollback(self, connection):
         """implementations might want to put logic here for turning autocommit on/off, etc."""
         if do_commit:
@@ -353,11 +365,20 @@ class MSSQLDialect(ansisql.ANSIDialect):
     def dbapi(self):
         return self.module
 
+    def uppercase_table(self, t):
+        # convert all names to uppercase -- fixes refs to INFORMATION_SCHEMA for case-senstive DBs, and won't matter for case-insensitive
+        t.name = t.name.upper()
+        if t.schema:
+            t.schema = t.schema.upper()
+        for c in t.columns:
+            c.name = c.name.upper()
+        return t
+
     def has_table(self, connection, tablename):
         import sqlalchemy.databases.information_schema as ischema
 
         current_schema = self.get_default_schema_name()
-        columns = ischema.columns
+        columns = self.uppercase_table(ischema.columns)
         s = sql.select([columns],
                    current_schema and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema) or columns.c.table_name==tablename,
                    )
@@ -375,7 +396,7 @@ class MSSQLDialect(ansisql.ANSIDialect):
         else:
             current_schema = self.get_default_schema_name()
 
-        columns = ischema.columns
+        columns = self.uppercase_table(ischema.columns)
         s = sql.select([columns],
                    current_schema and sql.and_(columns.c.table_name==table.name, columns.c.table_schema==current_schema) or columns.c.table_name==table.name,
                    order_by=[columns.c.ordinal_position])
@@ -429,10 +450,10 @@ class MSSQLDialect(ansisql.ANSIDialect):
                 ic.sequence = schema.Sequence(ic.name + '_identity')
 
         # Add constraints
-        RR = ischema.ref_constraints    #information_schema.referential_constraints
-        TC = ischema.constraints        #information_schema.table_constraints
-        C  = ischema.column_constraints.alias('C') #information_schema.constraint_column_usage: the constrained column 
-        R  = ischema.column_constraints.alias('R') #information_schema.constraint_column_usage: the referenced column
+        RR = self.uupercase_table(ischema.ref_constraints)    #information_schema.referential_constraints
+        TC = self.uupercase_table(ischema.constraints)        #information_schema.table_constraints
+        C  = self.uupercase_table(ischema.column_constraints).alias('C') #information_schema.constraint_column_usage: the constrained column 
+        R  = self.uupercase_table(ischema.column_constraints).alias('R') #information_schema.constraint_column_usage: the referenced column
 
         fromjoin = TC.join(RR, RR.c.constraint_name == TC.c.constraint_name).join(C, C.c.constraint_name == RR.c.constraint_name)
         fromjoin = fromjoin.join(R, R.c.constraint_name == RR.c.unique_constraint_name)
@@ -520,7 +541,7 @@ class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
         colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec()
 
         # install a IDENTITY Sequence if we have an implicit IDENTITY column
-        if column.primary_key and column.autoincrement and isinstance(column.type, sqltypes.Integer):
+        if column.primary_key and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not column.foreign_key:
             if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional):
                 column.sequence = schema.Sequence(column.name + '_seq')