From cce1f073e3ed924b83bfe15ae86ccf0d772b15c1 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 1 Jul 2006 19:26:30 +0000 Subject: [PATCH] got MS-SQL support largely working, including reflection, basic types, fair amount of ORM stuff, etc. 'rowcount' label is reseved in MS-SQL and had to change in sql.py count() as well as orm.query --- CHANGES | 1 + lib/sqlalchemy/databases/mssql.py | 58 ++++++++++++++++++++++--------- lib/sqlalchemy/engine/base.py | 4 ++- lib/sqlalchemy/engine/default.py | 1 + lib/sqlalchemy/orm/query.py | 2 +- lib/sqlalchemy/sql.py | 4 +-- test/orm/objectstore.py | 14 +++++--- test/testbase.py | 2 +- 8 files changed, 59 insertions(+), 27 deletions(-) diff --git a/CHANGES b/CHANGES index a36de9ea63..04a841ed67 100644 --- a/CHANGES +++ b/CHANGES @@ -5,6 +5,7 @@ two mappers that referenced each other working around new setuptools PYTHONPATH-killing behavior - further fixes with attributes/dependencies/etc.... - improved error handling for when DynamicMetaData is not connected +- MS-SQL support largely working (tested with pymssql) 0.2.4 - try/except when the mapper sets init.__name__ on a mapped class, diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 181069a953..89cc883989 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -52,6 +52,7 @@ try: [["Provider=SQLOLEDB;Data Source=%s;User Id=%s;Password=%s;Initial Catalog=%s" % ( keys["host"], keys["user"], keys["password"], keys["database"])], {}] do_commit = False + sane_rowcount = True except: try: import pymssql as dbmodule @@ -64,6 +65,7 @@ except: except: dbmodule = None make_connect_string = lambda keys: [[],{}] + sane_rowcount = False class MSNumeric(sqltypes.Numeric): def convert_result_value(self, value, dialect): @@ -195,12 +197,16 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): for c in compiled.statement.table.c: if hasattr(c,'sequence'): self.HASIDENT = True - if parameters.has_key(c.name): + if isinstance(parameters, list): + if parameters[0].has_key(c.name): + self.IINSERT = True + elif parameters.has_key(c.name): self.IINSERT = True break if self.IINSERT: proxy("SET IDENTITY_INSERT %s ON" % compiled.statement.table.name) - + super(MSSQLExecutionContext, self).pre_exec(engine, proxy, compiled, parameters, **kwargs) + def post_exec(self, engine, proxy, compiled, parameters, **kwargs): """ Turn off the INDENTITY_INSERT mode if it's been activated, and fetch recently inserted IDENTIFY values (works only for one column) """ if getattr(compiled, "isinsert", False): @@ -210,7 +216,8 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): elif self.HASIDENT: cursor = proxy("SELECT @@IDENTITY AS lastrowid") row = cursor.fetchone() - self.last_inserted_ids = [row[0]] + self._last_inserted_ids = [int(row[0])] + print "LAST ROW ID", self._last_inserted_ids self.HASIDENT = False class MSSQLDialect(ansisql.ANSIDialect): @@ -236,7 +243,7 @@ class MSSQLDialect(ansisql.ANSIDialect): return self.context.last_inserted_ids def supports_sane_rowcount(self): - return True + return sane_rowcount def compiler(self, statement, bindparams, **kwargs): return MSSQLCompiler(self, statement, bindparams, **kwargs) @@ -328,6 +335,19 @@ class MSSQLDialect(ansisql.ANSIDialect): def dbapi(self): return self.module + def has_table(self, connection, tablename): + import sqlalchemy.databases.information_schema as ischema + + current_schema = self.get_default_schema_name() + columns = ischema.columns + s = sql.select([columns], + current_schema and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema) or columns.c.table_name==tablename, + ) + + c = connection.execute(s) + row = c.fetchone() + return row is not None + def reflecttable(self, connection, table): import sqlalchemy.databases.information_schema as ischema @@ -338,7 +358,7 @@ class MSSQLDialect(ansisql.ANSIDialect): current_schema = self.get_default_schema_name() columns = ischema.columns - s = select([columns], + s = sql.select([columns], current_schema and sql.and_(columns.c.table_name==table.name, columns.c.table_schema==current_schema) or columns.c.table_name==table.name, order_by=[columns.c.ordinal_position]) @@ -363,11 +383,11 @@ class MSSQLDialect(ansisql.ANSIDialect): for a in (charlen, numericprec, numericscale): if a is not None: args.append(a) - coltype = ischema_names[type] + coltype = ischema_names[type] coltype = coltype(*args) colargs= [] if default is not None: - colargs.append(PassiveDefault(sql.text(default))) + colargs.append(schema.PassiveDefault(sql.text(default))) table.append_item(schema.Column(name, coltype, nullable=nullable, *colargs)) @@ -386,11 +406,12 @@ class MSSQLDialect(ansisql.ANSIDialect): col_name, type_name = row[3], row[5] if type_name.endswith("identity"): ic = table.c[col_name] + ic.primary_key = True # setup a psuedo-sequence to represent the identity attribute - we interpret this at table.create() time as the identity attribute ic.sequence = schema.Sequence(ic.name + '_identity') # Add constraints - RR = ischema.ref_constraints(self) #information_schema.referential_constraints + RR = ischema.ref_constraints #information_schema.referential_constraints TC = ischema.constraints #information_schema.table_constraints C = ischema.column_constraints.alias('C') #information_schema.constraint_column_usage: the constrained column R = ischema.column_constraints.alias('R') #information_schema.constraint_column_usage: the referenced column @@ -398,11 +419,12 @@ class MSSQLDialect(ansisql.ANSIDialect): fromjoin = TC.join(RR, RR.c.constraint_name == TC.c.constraint_name).join(C, C.c.constraint_name == RR.c.constraint_name) fromjoin = fromjoin.join(R, R.c.constraint_name == RR.c.unique_constraint_name) - s = select([TC.c.constraint_type, C.c.table_schema, C.c.table_name, C.c.column_name, + s = sql.select([TC.c.constraint_type, C.c.table_schema, C.c.table_name, C.c.column_name, R.c.table_schema, R.c.table_name, R.c.column_name], - and_(RR.c.constraint_schema == current_schema, C.c.table_name == table.name), - from_obj = [fromjoin] + sql.and_(RR.c.constraint_schema == current_schema, C.c.table_name == table.name), + from_obj = [fromjoin], use_labels=True ) + colmap = [TC.c.constraint_type, C.c.column_name, R.c.table_schema, R.c.table_name, R.c.column_name] c = connection.execute(s) @@ -410,20 +432,22 @@ class MSSQLDialect(ansisql.ANSIDialect): row = c.fetchone() if row is None: break + print "CCROW", row.keys(), row (type, constrained_column, referred_schema, referred_table, referred_column) = ( row[colmap[0]], + row[colmap[1]], + row[colmap[2]], row[colmap[3]], - row[colmap[4]], - row[colmap[5]], - row[colmap[6]] + row[colmap[4]] ) if type=='PRIMARY KEY': table.c[constrained_column]._set_primary_key() elif type=='FOREIGN KEY': - remotetable = Table(referred_table, self, autoload = True, schema=referred_schema) + if current_schema == referred_schema: + referred_schema = table.schema + remotetable = schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection, schema=referred_schema) table.c[constrained_column].append_item(schema.ForeignKey(remotetable.c[referred_column])) - class MSSQLCompiler(ansisql.ANSICompiler): @@ -470,7 +494,7 @@ class MSSQLCompiler(ansisql.ANSICompiler): super(MSSQLCompiler, self).visit_column(column) if column.table is not None and self.tablealiases.has_key(column.table): self.strings[column] = \ - self.strings[self.tablealiases[column.table].corresponding_column(column.original)] + self.strings[self.tablealiases[column.table].corresponding_column(column)] class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator): diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index d9e3f4ed83..85e68825f2 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -514,8 +514,10 @@ class ResultProxy: class AmbiguousColumn(object): def __init__(self, key): self.key = key + def dialect_impl(self, dialect): + return self def convert_result_value(self, arg, engine): - raise InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % (self.key)) + raise exceptions.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % (self.key)) def __init__(self, engine, connection, cursor, executioncontext=None, typemap=None): """ResultProxy objects are constructed via the execute() method on SQLEngine.""" diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 7ebce0c222..e318b6756b 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -194,6 +194,7 @@ class DefaultExecutionContext(base.ExecutionContext): self._last_inserted_ids = None else: self._last_inserted_ids = last_inserted_ids + print "LAST INSERTED PARAMS", param self._last_inserted_params = param elif getattr(compiled, 'isupdate', False): if isinstance(parameters, list): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index d27e23d846..985659eec5 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -332,7 +332,7 @@ class Query(object): # raise "ok first thing", str(s2) if not kwargs.get('distinct', False) and order_by: s2.order_by(*util.to_list(order_by)) - s3 = s2.alias('rowcount') + s3 = s2.alias('tbl_row_count') crit = [] for i in range(0, len(self.table.primary_key)): crit.append(s3.primary_key[i] == self.table.primary_key[i]) diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index d978ee208e..7b17927f08 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -687,7 +687,7 @@ class FromClause(Selectable): col = self.primary_key[0] else: col = list(self.columns)[0] - return select([func.count(col).label('rowcount')], whereclause, from_obj=[self], **params) + return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params) def join(self, right, *args, **kwargs): return Join(self, right, *args, **kwargs) def outerjoin(self, right, *args, **kwargs): @@ -1283,7 +1283,7 @@ class TableClause(FromClause): col = self.primary_key[0] else: col = list(self.columns)[0] - return select([func.count(col).label('rowcount')], whereclause, from_obj=[self], **params) + return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params) def join(self, right, *args, **kwargs): return Join(self, right, *args, **kwargs) def outerjoin(self, right, *args, **kwargs): diff --git a/test/orm/objectstore.py b/test/orm/objectstore.py index 18c64e9365..c2ef112a64 100644 --- a/test/orm/objectstore.py +++ b/test/orm/objectstore.py @@ -132,7 +132,7 @@ class VersioningTest(SessionTest): version_table.delete().execute() SessionTest.tearDown(self) - @testbase.unsupported('mysql') + @testbase.unsupported('mysql', 'mssql') def testbasic(self): s = create_session() class Foo(object):pass @@ -227,6 +227,7 @@ class UnicodeTest(SessionTest): assert len(t1.t2s) == 2 class PKTest(SessionTest): + @testbase.unsupported('mssql') def setUpAll(self): SessionTest.setUpAll(self) db.echo = False @@ -234,19 +235,19 @@ class PKTest(SessionTest): global table2 global table3 table = Table( - 'multi', db, + 'multipk', db, Column('multi_id', Integer, Sequence("multi_id_seq", optional=True), primary_key=True), Column('multi_rev', Integer, primary_key=True), Column('name', String(50), nullable=False), Column('value', String(100)) ) - table2 = Table('multi2', db, + table2 = Table('multipk2', db, Column('pk_col_1', String(30), primary_key=True), Column('pk_col_2', String(30), primary_key=True), Column('data', String(30), ) ) - table3 = Table('multi3', db, + table3 = Table('multipk3', db, Column('pri_code', String(30), key='primary', primary_key=True), Column('sec_code', String(30), key='secondary', primary_key=True), Column('date_assigned', Date, key='assigned', primary_key=True), @@ -256,6 +257,7 @@ class PKTest(SessionTest): table2.create() table3.create() db.echo = testbase.echo + @testbase.unsupported('mssql') def tearDownAll(self): db.echo = False table.drop() @@ -264,7 +266,7 @@ class PKTest(SessionTest): db.echo = testbase.echo SessionTest.tearDownAll(self) - @testbase.unsupported('sqlite') + @testbase.unsupported('sqlite', 'mssql') def testprimarykey(self): class Entry(object): pass @@ -277,6 +279,7 @@ class PKTest(SessionTest): ctx.current.clear() e2 = Entry.mapper.get((e.multi_id, 2)) self.assert_(e is not e2 and e._instance_key == e2._instance_key) + @testbase.unsupported('mssql') def testmanualpk(self): class Entry(object): pass @@ -286,6 +289,7 @@ class PKTest(SessionTest): e.pk_col_2 = 'pk1_related' e.data = 'im the data' ctx.current.flush() + @testbase.unsupported('mssql') def testkeypks(self): import datetime class Entity(object): diff --git a/test/testbase.py b/test/testbase.py index 8fbd6954c9..ddec64179c 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -72,7 +72,7 @@ def parse_argv(): db_uri = 'oracle://scott:tiger@127.0.0.1:1521' opts = {'use_ansi':False} elif DBTYPE == 'mssql': - db_uri = 'mssql://scott:tiger@/test' + db_uri = 'mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test' if not db_uri: raise "Could not create engine. specify --db to test runner." -- 2.47.3