From 5545aed3a66d7612062b410a9743e8c3f52ddd3c Mon Sep 17 00:00:00 2001 From: Michael Trier Date: Tue, 31 Mar 2009 07:14:44 +0000 Subject: [PATCH] Corrections to 0.6 to fix mssql problems. --- lib/sqlalchemy/dialects/information_schema.py | 191 ------------------ lib/sqlalchemy/dialects/mssql/base.py | 98 +++++---- .../dialects/mssql/information_schema.py | 11 +- lib/sqlalchemy/dialects/mssql/pyodbc.py | 1 + test/dialect/mssql.py | 43 +++- test/engine/reflection.py | 8 +- test/sql/testtypes.py | 1 - 7 files changed, 109 insertions(+), 244 deletions(-) delete mode 100644 lib/sqlalchemy/dialects/information_schema.py diff --git a/lib/sqlalchemy/dialects/information_schema.py b/lib/sqlalchemy/dialects/information_schema.py deleted file mode 100644 index 9a65cca4cd..0000000000 --- a/lib/sqlalchemy/dialects/information_schema.py +++ /dev/null @@ -1,191 +0,0 @@ -import sqlalchemy.sql as sql -import sqlalchemy.exc as exc -from sqlalchemy import select, MetaData, Table, Column, String, Integer -from sqlalchemy.schema import DefaultClause, ForeignKeyConstraint - -ischema = MetaData() - -schemata = Table("schemata", ischema, - Column("catalog_name", String), - Column("schema_name", String), - Column("schema_owner", String), - schema="information_schema") - -tables = Table("tables", ischema, - Column("table_catalog", String), - Column("table_schema", String), - Column("table_name", String), - Column("table_type", String), - schema="information_schema") - -columns = Table("columns", ischema, - Column("table_schema", String), - Column("table_name", String), - Column("column_name", String), - Column("is_nullable", Integer), - Column("data_type", String), - Column("ordinal_position", Integer), - Column("character_maximum_length", Integer), - Column("numeric_precision", Integer), - Column("numeric_scale", Integer), - Column("column_default", Integer), - Column("collation_name", String), - schema="information_schema") - -constraints = Table("table_constraints", ischema, - Column("table_schema", String), - Column("table_name", String), - Column("constraint_name", String), - Column("constraint_type", String), - schema="information_schema") - -column_constraints = Table("constraint_column_usage", ischema, - Column("table_schema", String), - Column("table_name", String), - Column("column_name", String), - Column("constraint_name", String), - schema="information_schema") - -pg_key_constraints = Table("key_column_usage", ischema, - Column("table_schema", String), - Column("table_name", String), - Column("column_name", String), - Column("constraint_name", String), - Column("ordinal_position", Integer), - schema="information_schema") - -#mysql_key_constraints = Table("key_column_usage", ischema, -# Column("table_schema", String), -# Column("table_name", String), -# Column("column_name", String), -# Column("constraint_name", String), -# Column("referenced_table_schema", String), -# Column("referenced_table_name", String), -# Column("referenced_column_name", String), -# schema="information_schema") - -key_constraints = pg_key_constraints - -ref_constraints = Table("referential_constraints", ischema, - Column("constraint_catalog", String), - Column("constraint_schema", String), - Column("constraint_name", String), - Column("unique_constraint_catlog", String), - Column("unique_constraint_schema", String), - Column("unique_constraint_name", String), - Column("match_option", String), - Column("update_rule", String), - Column("delete_rule", String), - schema="information_schema") - -views = Table("views", ischema, - Column("table_catalog", String), - Column("table_schema", String), - Column("table_name", String), - Column("view_definition", String), - Column("check_option", String), - Column("is_updatable", String), - schema="information_schema") - -def table_names(connection, schema): - s = select([tables.c.table_name], tables.c.table_schema==schema) - return [row[0] for row in connection.execute(s)] - - -def reflecttable(connection, table, include_columns, ischema_names): - key_constraints = pg_key_constraints - - if table.schema is not None: - current_schema = table.schema - else: - current_schema = connection.default_schema_name() - - s = select([columns], - sql.and_(columns.c.table_name==table.name, - columns.c.table_schema==current_schema), - order_by=[columns.c.ordinal_position]) - - c = connection.execute(s) - found_table = False - while True: - row = c.fetchone() - if row is None: - break - #print "row! " + repr(row) - # continue - found_table = True - (name, type, nullable, charlen, numericprec, numericscale, default) = ( - row[columns.c.column_name], - row[columns.c.data_type], - row[columns.c.is_nullable] == 'YES', - row[columns.c.character_maximum_length], - row[columns.c.numeric_precision], - row[columns.c.numeric_scale], - row[columns.c.column_default] - ) - if include_columns and name not in include_columns: - continue - - args = [] - for a in (charlen, numericprec, numericscale): - if a is not None: - args.append(a) - coltype = ischema_names[type] - #print "coltype " + repr(coltype) + " args " + repr(args) - coltype = coltype(*args) - colargs = [] - if default is not None: - colargs.append(DefaultClause(sql.text(default))) - table.append_column(Column(name, coltype, nullable=nullable, *colargs)) - - if not found_table: - raise exc.NoSuchTableError(table.name) - - # we are relying on the natural ordering of the constraint_column_usage table to return the referenced columns - # in an order that corresponds to the ordinal_position in the key_constraints table, otherwise composite foreign keys - # wont reflect properly. dont see a way around this based on whats available from information_schema - s = select([constraints.c.constraint_name, constraints.c.constraint_type, constraints.c.table_name, key_constraints], use_labels=True, from_obj=[constraints.join(column_constraints, column_constraints.c.constraint_name==constraints.c.constraint_name).join(key_constraints, key_constraints.c.constraint_name==column_constraints.c.constraint_name)], order_by=[key_constraints.c.ordinal_position]) - s.append_column(column_constraints) - s.append_whereclause(constraints.c.table_name==table.name) - s.append_whereclause(constraints.c.table_schema==current_schema) - colmap = [constraints.c.constraint_type, key_constraints.c.column_name, column_constraints.c.table_schema, column_constraints.c.table_name, column_constraints.c.column_name, constraints.c.constraint_name, key_constraints.c.ordinal_position] - c = connection.execute(s) - - fks = {} - while True: - row = c.fetchone() - if row is None: - break - (type, constrained_column, referred_schema, referred_table, referred_column, constraint_name, ordinal_position) = ( - row[colmap[0]], - row[colmap[1]], - row[colmap[2]], - row[colmap[3]], - row[colmap[4]], - row[colmap[5]], - row[colmap[6]] - ) - #print "type %s on column %s to remote %s.%s.%s" % (type, constrained_column, referred_schema, referred_table, referred_column) - if type == 'PRIMARY KEY': - table.primary_key.add(table.c[constrained_column]) - elif type == 'FOREIGN KEY': - try: - fk = fks[constraint_name] - except KeyError: - fk = ([], []) - fks[constraint_name] = fk - if current_schema == referred_schema: - referred_schema = table.schema - if referred_schema is not None: - Table(referred_table, table.metadata, autoload=True, schema=referred_schema, autoload_with=connection) - refspec = ".".join([referred_schema, referred_table, referred_column]) - else: - Table(referred_table, table.metadata, autoload=True, autoload_with=connection) - refspec = ".".join([referred_table, referred_column]) - if constrained_column not in fk[0]: - fk[0].append(constrained_column) - if refspec not in fk[1]: - fk[1].append(refspec) - - for name, value in fks.iteritems(): - table.append_constraint(ForeignKeyConstraint(value[0], value[1], name=name)) diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 13541524a5..98bc716ab6 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -230,7 +230,7 @@ Known Issues import datetime, decimal, inspect, operator, sys, re from sqlalchemy import sql, schema, exc, util -from sqlalchemy.sql import compiler, expression, operators as sql_operators, functions as sql_functions +from sqlalchemy.sql import select, compiler, expression, operators as sql_operators, functions as sql_functions from sqlalchemy.engine import default, base, reflection from sqlalchemy import types as sqltypes from decimal import Decimal as _python_Decimal @@ -238,8 +238,8 @@ from decimal import Decimal as _python_Decimal import information_schema as ischema MS_2008_VERSION = (10,) -#MS_2005_VERSION = ?? -#MS_2000_VERSION = ?? +MS_2005_VERSION = (9,) +MS_2000_VERSION = (8,) MSSQL_RESERVED_WORDS = set(['function']) @@ -430,7 +430,9 @@ class MSNText(_StringType, sqltypes.UnicodeText): """ collation = kwargs.pop('collation', None) _StringType.__init__(self, collation) - sqltypes.UnicodeText.__init__(self, None, **kwargs) + length = kwargs.pop('length', None) + sqltypes.UnicodeText.__init__(self, length, **kwargs) + class MSString(_StringType, sqltypes.VARCHAR): @@ -649,6 +651,18 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): def visit_SMALLDATETIME(self, type_): return "SMALLDATETIME" + def visit_string(self, type_): + if type_.convert_unicode: + return self._extend("NVARCHAR", type_) + else: + return self._extend("VARCHAR", type_) + + def visit_text(self, type_): + if type_.convert_unicode: + return self._extend("NTEXT", type_) + else: + return self._extend("TEXT", type_) + def visit_NTEXT(self, type_): return self._extend("NTEXT", type_) @@ -793,6 +807,7 @@ colspecs = { sqltypes.UnicodeText : MSNText, sqltypes.CHAR: MSChar, sqltypes.NCHAR: MSNChar, + MSSmallDateTime: MSSmallDateTime, } ischema_names = { @@ -1112,11 +1127,13 @@ class MSDialect(default.DefaultDialect): return self.schema_name def table_names(self, connection, schema): - return ischema.table_names(connection, schema) + s = select([ischema.tables.c.table_name], ischema.tables.c.table_schema==schema) + return [row[0] for row in connection.execute(s)] + def has_table(self, connection, tablename, schema=None): current_schema = schema or self.get_default_schema_name(connection) - columns = self.uppercase_table(ischema.columns) + columns = ischema.columns s = sql.select([columns], current_schema and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema) @@ -1128,17 +1145,17 @@ class MSDialect(default.DefaultDialect): return row is not None @reflection.cache - def get_schema_names(self, connection, info_cache=None): - s = sql.select([self.uppercase_table(ischema.schemata).c.schema_name], + def get_schema_names(self, connection): + s = sql.select([ischema.schemata.c.schema_name], order_by=[ischema.schemata.c.schema_name] ) schema_names = [r[0] for r in connection.execute(s)] return schema_names @reflection.cache - def get_table_names(self, connection, schemaname, info_cache=None): - current_schema = schemaname or self.get_default_schema_name(connection) - tables = self.uppercase_table(ischema.tables) + def get_table_names(self, connection, schema=None, **kw): + current_schema = schema or self.get_default_schema_name(connection) + tables = ischema.tables s = sql.select([tables.c.table_name], sql.and_( tables.c.table_schema == current_schema, @@ -1150,9 +1167,9 @@ class MSDialect(default.DefaultDialect): return table_names @reflection.cache - def get_view_names(self, connection, schemaname=None, info_cache=None): - current_schema = schemaname or self.get_default_schema_name(connection) - tables = self.uppercase_table(ischema.tables) + def get_view_names(self, connection, schema=None, **kw): + current_schema = schema or self.get_default_schema_name(connection) + tables = ischema.tables s = sql.select([tables.c.table_name], sql.and_( tables.c.table_schema == current_schema, @@ -1164,9 +1181,8 @@ class MSDialect(default.DefaultDialect): return view_names @reflection.cache - def get_indexes(self, connection, tablename, schemaname=None, - info_cache=None): - current_schema = schemaname or self.get_default_schema_name(connection) + def get_indexes(self, connection, tablename, schema=None, **kw): + current_schema = schema or self.get_default_schema_name(connection) full_tname = "%s.%s" % (current_schema, tablename) indexes = [] s = sql.text("exec sp_helpindex '%s'" % full_tname) @@ -1181,10 +1197,9 @@ class MSDialect(default.DefaultDialect): return indexes @reflection.cache - def get_view_definition(self, connection, viewname, schemaname=None, - info_cache=None): - current_schema = schemaname or self.get_default_schema_name(connection) - views = self.uppercase_table(ischema.views) + def get_view_definition(self, connection, viewname, schema=None, **kw): + current_schema = schema or self.get_default_schema_name(connection) + views = ischema.views s = sql.select([views.c.view_definition], sql.and_( views.c.table_schema == current_schema, @@ -1197,11 +1212,10 @@ class MSDialect(default.DefaultDialect): return view_def @reflection.cache - def get_columns(self, connection, tablename, schemaname=None, - info_cache=None): + def get_columns(self, connection, tablename, schema=None, **kw): # Get base columns current_schema = schemaname or self.get_default_schema_name(connection) - columns = self.uppercase_table(ischema.columns) + columns = ischema.columns s = sql.select([columns], current_schema and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema) @@ -1257,15 +1271,14 @@ class MSDialect(default.DefaultDialect): return cols @reflection.cache - def get_primary_keys(self, connection, tablename, schemaname=None, - info_cache=None): - current_schema = schemaname or self.get_default_schema_name(connection) + def get_primary_keys(self, connection, tablename, schema=None, **kw): + current_schema = schema or self.get_default_schema_name(connection) pkeys = [] # Add constraints - RR = self.uppercase_table(ischema.ref_constraints) #information_schema.referential_constraints - TC = self.uppercase_table(ischema.constraints) #information_schema.table_constraints - C = self.uppercase_table(ischema.pg_key_constraints).alias('C') #information_schema.constraint_column_usage: the constrained column - R = self.uppercase_table(ischema.pg_key_constraints).alias('R') #information_schema.constraint_column_usage: the referenced column + RR = ischema.ref_constraints #information_schema.referential_constraints + TC = ischema.constraints #information_schema.table_constraints + C = ischema.key_constraints.alias('C') #information_schema.constraint_column_usage: the constrained column + R = ischema.key_constraints.alias('R') #information_schema.constraint_column_usage: the referenced column # Primary key constraints s = sql.select([C.c.column_name, TC.c.constraint_type], @@ -1280,14 +1293,13 @@ class MSDialect(default.DefaultDialect): return pkeys @reflection.cache - def get_foreign_keys(self, connection, tablename, schemaname=None, - info_cache=None): - current_schema = schemaname or self.get_default_schema_name(connection) + def get_foreign_keys(self, connection, tablename, schema=None, **kw): + current_schema = schema or self.get_default_schema_name(connection) # Add constraints - RR = self.uppercase_table(ischema.ref_constraints) #information_schema.referential_constraints - TC = self.uppercase_table(ischema.constraints) #information_schema.table_constraints - C = self.uppercase_table(ischema.pg_key_constraints).alias('C') #information_schema.constraint_column_usage: the constrained column - R = self.uppercase_table(ischema.pg_key_constraints).alias('R') #information_schema.constraint_column_usage: the referenced column + RR = ischema.ref_constraints #information_schema.referential_constraints + TC = ischema.constraints #information_schema.table_constraints + C = ischema.key_constraints.alias('C') #information_schema.constraint_column_usage: the constrained column + R = ischema.key_constraints.alias('R') #information_schema.constraint_column_usage: the referenced column # Foreign key constraints s = sql.select([C.c.column_name, @@ -1337,9 +1349,8 @@ class MSDialect(default.DefaultDialect): current_schema = table.schema else: current_schema = self.get_default_schema_name(connection) - info_cache = MSInfoCache() - columns = self.get_columns(connection, table.name, current_schema, - info_cache) + columns = self.get_columns(connection, table.name, current_schema) + found_table = False for cdict in columns: name = cdict['name'] @@ -1384,7 +1395,7 @@ class MSDialect(default.DefaultDialect): # Primary key constraints pkeys = self.get_primary_keys(connection, table.name, - current_schema, info_cache) + current_schema) for pkey in pkeys: if pkey in table.c: table.primary_key.add(table.c[pkey]) @@ -1396,8 +1407,7 @@ class MSDialect(default.DefaultDialect): else: return '.'.join([rschema, rtbl, rcol]) - fkeys = self.get_foreign_keys(connection, table.name, current_schema, - info_cache) + fkeys = self.get_foreign_keys(connection, table.name, current_schema) for fkey_d in fkeys: fknm = fkey_d['name'] scols = fkey_d['constrained_columns'] diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py index 447c7b5700..84269bf181 100644 --- a/lib/sqlalchemy/dialects/mssql/information_schema.py +++ b/lib/sqlalchemy/dialects/mssql/information_schema.py @@ -1,7 +1,7 @@ from sqlalchemy import Table, MetaData, Column, ForeignKey, String, Integer ischema = MetaData() - + schemata = Table("SCHEMATA", ischema, Column("CATALOG_NAME", String, key="catalog_name"), Column("SCHEMA_NAME", String, key="schema_name"), @@ -63,3 +63,12 @@ ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema, Column("DELETE_RULE", String, key="delete_rule"), schema="INFORMATION_SCHEMA") +views = Table("VIEWS", ischema, + Column("TABLE_CATALOG", String), + Column("TABLE_SCHEMA", String), + Column("TABLE_NAME", String), + Column("VIEW_DEFINITION", String), + Column("CHECK_OPTION", String), + Column("IS_UPDATABLE", String), + schema="INFORMATION_SCHEMA") + diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py index 1b67cc04c4..13180ec609 100644 --- a/lib/sqlalchemy/dialects/mssql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -32,6 +32,7 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): supports_sane_multi_rowcount = False # PyODBC unicode is broken on UCS-4 builds supports_unicode = sys.maxunicode == 65535 + supports_unicode_binds = supports_unicode supports_unicode_statements = supports_unicode execution_ctx_cls = MSExecutionContext_pyodbc diff --git a/test/dialect/mssql.py b/test/dialect/mssql.py index 3ce8a9220b..ecaf7b5a6d 100755 --- a/test/dialect/mssql.py +++ b/test/dialect/mssql.py @@ -181,9 +181,50 @@ class IdentityInsertTest(TestBase, AssertsCompiledSQL): self.assertEqual([(91, 'Smalltalk'), (90, 'PHP')], list(lastcats)) -class ReflectionTest(TestBase): +class ReflectionTest(TestBase, ComparesTables): __only_on__ = 'mssql' + def test_basic_reflection(self): + meta = MetaData(testing.db) + + users = Table('engine_users', meta, + Column('user_id', types.INT, primary_key=True), + Column('user_name', types.VARCHAR(20), nullable=False), + Column('test1', types.CHAR(5), nullable=False), + Column('test2', types.Float(5), nullable=False), + Column('test3', types.Text), + Column('test4', types.Numeric, nullable = False), + Column('test5', types.DateTime), + Column('parent_user_id', types.Integer, + ForeignKey('engine_users.user_id')), + Column('test6', types.DateTime, nullable=False), + Column('test7', types.Text), + Column('test8', types.Binary), + Column('test_passivedefault2', types.Integer, server_default='5'), + Column('test9', types.Binary(100)), + Column('test_numeric', types.Numeric()), + test_needs_fk=True, + ) + + addresses = Table('engine_email_addresses', meta, + Column('address_id', types.Integer, primary_key = True), + Column('remote_user_id', types.Integer, ForeignKey(users.c.user_id)), + Column('email_address', types.String(20)), + test_needs_fk=True, + ) + meta.create_all() + + try: + meta2 = MetaData() + reflected_users = Table('engine_users', meta2, autoload=True, + autoload_with=testing.db) + reflected_addresses = Table('engine_email_addresses', meta2, + autoload=True, autoload_with=testing.db) + self.assert_tables_equal(users, reflected_users) + self.assert_tables_equal(addresses, reflected_addresses) + finally: + meta.drop_all() + def testidentity(self): meta = MetaData(testing.db) table = Table( diff --git a/test/engine/reflection.py b/test/engine/reflection.py index f133638fd5..0bd1d7b9e3 100644 --- a/test/engine/reflection.py +++ b/test/engine/reflection.py @@ -12,6 +12,7 @@ metadata, users = None, None class ReflectionTest(TestBase, ComparesTables): + @testing.exclude('mssql', '<', (10, 0, 0), 'Date is only supported on MSSQL 2008+') @testing.exclude('mysql', '<', (4, 1, 1), 'early types are squirrely') def test_basic_reflection(self): meta = MetaData(testing.db) @@ -801,12 +802,7 @@ class HasSequenceTest(TestBase): # Tests related to engine.reflection def get_schema(): -# if testing.against('sqlite'): -# return None - if testing.against('oracle'): - return 'test' - else: - return 'test_schema' + return 'alt_schema' def createTables(meta, schema=None): if schema: diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index 1a78e463ae..74883cafa3 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -411,7 +411,6 @@ class BinaryTest(TestBase, AssertsExecutionResults): def tearDownAll(self): binary_table.drop() - @testing.fails_on('mssql', 'MSSQl BINARY type right pads the fixed length with \x00') def test_round_trip(self): testobj1 = pickleable.Foo('im foo 1') testobj2 = pickleable.Foo('im foo 2') -- 2.47.3