]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
got MS-SQL support largely working, including reflection, basic types, fair amount...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Jul 2006 19:26:30 +0000 (19:26 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Jul 2006 19:26:30 +0000 (19:26 +0000)
'rowcount' label is reseved in MS-SQL and had to change in sql.py count() as well as orm.query

CHANGES
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql.py
test/orm/objectstore.py
test/testbase.py

diff --git a/CHANGES b/CHANGES
index a36de9ea636aedcca899615c88cd18263b10f3ce..04a841ed674c2b39cbeacdc037597a17270fdcef 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -5,6 +5,7 @@ two mappers that referenced each other
 working around new setuptools PYTHONPATH-killing behavior
 - further fixes with attributes/dependencies/etc....
 - improved error handling for when DynamicMetaData is not connected
+- MS-SQL support largely working (tested with pymssql)
 
 0.2.4
 - try/except when the mapper sets init.__name__ on a mapped class,
index 181069a953ea91bc6501a6d775dfd3f72a582c43..89cc883989a7f83fca5bba862643e323fb678e57 100644 (file)
@@ -52,6 +52,7 @@ try:
         [["Provider=SQLOLEDB;Data Source=%s;User Id=%s;Password=%s;Initial Catalog=%s" % (
             keys["host"], keys["user"], keys["password"], keys["database"])], {}]
     do_commit = False
+    sane_rowcount = True
 except:
     try:
         import pymssql as dbmodule
@@ -64,6 +65,7 @@ except:
     except:
         dbmodule = None
         make_connect_string = lambda keys: [[],{}]
+    sane_rowcount = False
     
 class MSNumeric(sqltypes.Numeric):
     def convert_result_value(self, value, dialect):
@@ -195,12 +197,16 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
             for c in compiled.statement.table.c:
                 if hasattr(c,'sequence'):
                     self.HASIDENT = True
-                    if parameters.has_key(c.name):
+                    if isinstance(parameters, list):
+                        if parameters[0].has_key(c.name):
+                            self.IINSERT = True
+                    elif parameters.has_key(c.name):
                         self.IINSERT = True
                     break
             if self.IINSERT:
                 proxy("SET IDENTITY_INSERT %s ON" % compiled.statement.table.name)
-
+       super(MSSQLExecutionContext, self).pre_exec(engine, proxy, compiled, parameters, **kwargs)
+       
     def post_exec(self, engine, proxy, compiled, parameters, **kwargs):
         """ Turn off the INDENTITY_INSERT mode if it's been activated, and fetch recently inserted IDENTIFY values (works only for one column) """
         if getattr(compiled, "isinsert", False):
@@ -210,7 +216,8 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
             elif self.HASIDENT:
                 cursor = proxy("SELECT @@IDENTITY AS lastrowid")
                 row = cursor.fetchone()
-                self.last_inserted_ids = [row[0]]
+                self._last_inserted_ids = [int(row[0])]
+                print "LAST ROW ID", self._last_inserted_ids
             self.HASIDENT = False
 
 class MSSQLDialect(ansisql.ANSIDialect):            
@@ -236,7 +243,7 @@ class MSSQLDialect(ansisql.ANSIDialect):
         return self.context.last_inserted_ids
 
     def supports_sane_rowcount(self):
-        return True
+        return sane_rowcount
 
     def compiler(self, statement, bindparams, **kwargs):
         return MSSQLCompiler(self, statement, bindparams, **kwargs)
@@ -328,6 +335,19 @@ class MSSQLDialect(ansisql.ANSIDialect):
     def dbapi(self):
         return self.module
 
+    def has_table(self, connection, tablename):
+        import sqlalchemy.databases.information_schema as ischema
+
+        current_schema = self.get_default_schema_name()
+        columns = ischema.columns
+        s = sql.select([columns],
+                   current_schema and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema) or columns.c.table_name==tablename,
+                   )
+        
+        c = connection.execute(s)
+        row  = c.fetchone()
+        return row is not None
+        
     def reflecttable(self, connection, table):
         import sqlalchemy.databases.information_schema as ischema
         
@@ -338,7 +358,7 @@ class MSSQLDialect(ansisql.ANSIDialect):
             current_schema = self.get_default_schema_name()
 
         columns = ischema.columns
-        s = select([columns],
+        s = sql.select([columns],
                    current_schema and sql.and_(columns.c.table_name==table.name, columns.c.table_schema==current_schema) or columns.c.table_name==table.name,
                    order_by=[columns.c.ordinal_position])
         
@@ -363,11 +383,11 @@ class MSSQLDialect(ansisql.ANSIDialect):
             for a in (charlen, numericprec, numericscale):
                 if a is not None:
                     args.append(a)
-                    coltype = ischema_names[type]
+            coltype = ischema_names[type]
             coltype = coltype(*args)
             colargs= []
             if default is not None:
-                colargs.append(PassiveDefault(sql.text(default)))
+                colargs.append(schema.PassiveDefault(sql.text(default)))
                 
             table.append_item(schema.Column(name, coltype, nullable=nullable, *colargs))
         
@@ -386,11 +406,12 @@ class MSSQLDialect(ansisql.ANSIDialect):
             col_name, type_name = row[3], row[5]
             if type_name.endswith("identity"):
                 ic = table.c[col_name]
+                ic.primary_key = True
                 # setup a psuedo-sequence to represent the identity attribute - we interpret this at table.create() time as the identity attribute
                 ic.sequence = schema.Sequence(ic.name + '_identity')
 
         # Add constraints
-        RR = ischema.ref_constraints(self)    #information_schema.referential_constraints
+        RR = ischema.ref_constraints    #information_schema.referential_constraints
         TC = ischema.constraints        #information_schema.table_constraints
         C  = ischema.column_constraints.alias('C') #information_schema.constraint_column_usage: the constrained column 
         R  = ischema.column_constraints.alias('R') #information_schema.constraint_column_usage: the referenced column
@@ -398,11 +419,12 @@ class MSSQLDialect(ansisql.ANSIDialect):
         fromjoin = TC.join(RR, RR.c.constraint_name == TC.c.constraint_name).join(C, C.c.constraint_name == RR.c.constraint_name)
         fromjoin = fromjoin.join(R, R.c.constraint_name == RR.c.unique_constraint_name)
 
-        s = select([TC.c.constraint_type, C.c.table_schema, C.c.table_name, C.c.column_name,
+        s = sql.select([TC.c.constraint_type, C.c.table_schema, C.c.table_name, C.c.column_name,
                     R.c.table_schema, R.c.table_name, R.c.column_name],
-                   and_(RR.c.constraint_schema == current_schema,  C.c.table_name == table.name),
-                   from_obj = [fromjoin]
+                   sql.and_(RR.c.constraint_schema == current_schema,  C.c.table_name == table.name),
+                   from_obj = [fromjoin], use_labels=True
                    )
+        colmap = [TC.c.constraint_type, C.c.column_name, R.c.table_schema, R.c.table_name, R.c.column_name]
                
         c = connection.execute(s)
 
@@ -410,20 +432,22 @@ class MSSQLDialect(ansisql.ANSIDialect):
             row = c.fetchone()
             if row is None:
                 break
+            print "CCROW", row.keys(), row
             (type, constrained_column, referred_schema, referred_table, referred_column) = (
                 row[colmap[0]],
+                row[colmap[1]],
+                row[colmap[2]],
                 row[colmap[3]],
-                row[colmap[4]],
-                row[colmap[5]],
-                row[colmap[6]]
+                row[colmap[4]]
                 )
 
             if type=='PRIMARY KEY':
                 table.c[constrained_column]._set_primary_key()
             elif type=='FOREIGN KEY':
-                remotetable = Table(referred_table, self, autoload = True, schema=referred_schema)
+                if current_schema == referred_schema:
+                    referred_schema = table.schema
+                remotetable = schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection, schema=referred_schema)
                 table.c[constrained_column].append_item(schema.ForeignKey(remotetable.c[referred_column]))
-        
 
 
 class MSSQLCompiler(ansisql.ANSICompiler):
@@ -470,7 +494,7 @@ class MSSQLCompiler(ansisql.ANSICompiler):
         super(MSSQLCompiler, self).visit_column(column)
         if column.table is not None and self.tablealiases.has_key(column.table):
             self.strings[column] = \
-                self.strings[self.tablealiases[column.table].corresponding_column(column.original)]
+                self.strings[self.tablealiases[column.table].corresponding_column(column)]
 
         
 class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
index d9e3f4ed835e3e743bf5b2bb42d5e312df92a145..85e68825f2b14c6917ec31fc52922a7122f31ea6 100644 (file)
@@ -514,8 +514,10 @@ class ResultProxy:
     class AmbiguousColumn(object):
         def __init__(self, key):
             self.key = key
+        def dialect_impl(self, dialect):
+            return self
         def convert_result_value(self, arg, engine):
-            raise InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % (self.key))
+            raise exceptions.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % (self.key))
     
     def __init__(self, engine, connection, cursor, executioncontext=None, typemap=None):
         """ResultProxy objects are constructed via the execute() method on SQLEngine."""
index 7ebce0c22223c17403987273afed23139f11d35d..e318b6756b0292f0c025508041942a2a314cff28 100644 (file)
@@ -194,6 +194,7 @@ class DefaultExecutionContext(base.ExecutionContext):
                     self._last_inserted_ids = None
                 else:
                     self._last_inserted_ids = last_inserted_ids
+                print "LAST INSERTED PARAMS", param
                 self._last_inserted_params = param
         elif getattr(compiled, 'isupdate', False):
             if isinstance(parameters, list):
index d27e23d846988ad4ead995a4818014ffdf3a6882..985659eec5274739bf80676c05bb120eccb41054 100644 (file)
@@ -332,7 +332,7 @@ class Query(object):
 #            raise "ok first thing", str(s2)
             if not kwargs.get('distinct', False) and order_by:
                 s2.order_by(*util.to_list(order_by))
-            s3 = s2.alias('rowcount')
+            s3 = s2.alias('tbl_row_count')
             crit = []
             for i in range(0, len(self.table.primary_key)):
                 crit.append(s3.primary_key[i] == self.table.primary_key[i])
index d978ee208ee8156c9954e17217a9894dc0160617..7b17927f0896b57583b1dd8d1cebc3701a062542 100644 (file)
@@ -687,7 +687,7 @@ class FromClause(Selectable):
             col = self.primary_key[0]
         else:
             col = list(self.columns)[0]
-        return select([func.count(col).label('rowcount')], whereclause, from_obj=[self], **params)
+        return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params)
     def join(self, right, *args, **kwargs):
         return Join(self, right, *args, **kwargs)
     def outerjoin(self, right, *args, **kwargs):
@@ -1283,7 +1283,7 @@ class TableClause(FromClause):
             col = self.primary_key[0]
         else:
             col = list(self.columns)[0]
-        return select([func.count(col).label('rowcount')], whereclause, from_obj=[self], **params)
+        return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params)
     def join(self, right, *args, **kwargs):
         return Join(self, right, *args, **kwargs)
     def outerjoin(self, right, *args, **kwargs):
index 18c64e93650839dbd8173d20e3c95321efc439f5..c2ef112a649d717f2bdeb2008636c1e247586772 100644 (file)
@@ -132,7 +132,7 @@ class VersioningTest(SessionTest):
         version_table.delete().execute()
         SessionTest.tearDown(self)
     
-    @testbase.unsupported('mysql')
+    @testbase.unsupported('mysql', 'mssql')
     def testbasic(self):
         s = create_session()
         class Foo(object):pass
@@ -227,6 +227,7 @@ class UnicodeTest(SessionTest):
         assert len(t1.t2s) == 2
         
 class PKTest(SessionTest):
+    @testbase.unsupported('mssql')
     def setUpAll(self):
         SessionTest.setUpAll(self)
         db.echo = False
@@ -234,19 +235,19 @@ class PKTest(SessionTest):
         global table2
         global table3
         table = Table(
-            'multi', db, 
+            'multipk', db, 
             Column('multi_id', Integer, Sequence("multi_id_seq", optional=True), primary_key=True),
             Column('multi_rev', Integer, primary_key=True),
             Column('name', String(50), nullable=False),
             Column('value', String(100))
         )
         
-        table2 = Table('multi2', db,
+        table2 = Table('multipk2', db,
             Column('pk_col_1', String(30), primary_key=True),
             Column('pk_col_2', String(30), primary_key=True),
             Column('data', String(30), )
             )
-        table3 = Table('multi3', db,
+        table3 = Table('multipk3', db,
             Column('pri_code', String(30), key='primary', primary_key=True),
             Column('sec_code', String(30), key='secondary', primary_key=True),
             Column('date_assigned', Date, key='assigned', primary_key=True),
@@ -256,6 +257,7 @@ class PKTest(SessionTest):
         table2.create()
         table3.create()
         db.echo = testbase.echo
+    @testbase.unsupported('mssql')
     def tearDownAll(self):
         db.echo = False
         table.drop()
@@ -264,7 +266,7 @@ class PKTest(SessionTest):
         db.echo = testbase.echo
         SessionTest.tearDownAll(self)
         
-    @testbase.unsupported('sqlite')
+    @testbase.unsupported('sqlite', 'mssql')
     def testprimarykey(self):
         class Entry(object):
             pass
@@ -277,6 +279,7 @@ class PKTest(SessionTest):
         ctx.current.clear()
         e2 = Entry.mapper.get((e.multi_id, 2))
         self.assert_(e is not e2 and e._instance_key == e2._instance_key)
+    @testbase.unsupported('mssql')
     def testmanualpk(self):
         class Entry(object):
             pass
@@ -286,6 +289,7 @@ class PKTest(SessionTest):
         e.pk_col_2 = 'pk1_related'
         e.data = 'im the data'
         ctx.current.flush()
+    @testbase.unsupported('mssql')
     def testkeypks(self):
         import datetime
         class Entity(object):
index 8fbd6954c9e275fc4bc63f1b10d0e89f7c501dc0..ddec64179c817fedc5aa42bac8a8435966f2823d 100644 (file)
@@ -72,7 +72,7 @@ def parse_argv():
             db_uri = 'oracle://scott:tiger@127.0.0.1:1521'
             opts = {'use_ansi':False}
         elif DBTYPE == 'mssql':
-            db_uri = 'mssql://scott:tiger@/test'
+            db_uri = 'mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test'
 
     if not db_uri:
         raise "Could not create engine.  specify --db <sqlite|sqlite_file|postgres|mysql|oracle|oracle8|mssql> to test runner."