From: Jason Kirtland Date: Thu, 19 Jul 2007 19:24:51 +0000 (+0000) Subject: Better quoting of identifiers when manipulating schemas. X-Git-Tag: rel_0_4_6~70 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a72897f9aa65da07e9ff03cdb081cdd639e392fa;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Better quoting of identifiers when manipulating schemas. --- diff --git a/CHANGES b/CHANGES index 8aa7f6dd69..6f72ac3820 100644 --- a/CHANGES +++ b/CHANGES @@ -128,6 +128,7 @@ - added "explcit" create/drop/execute support for sequences (i.e. you can pass a "connectable" to each of those methods on Sequence) + - better quoting of identifiers when manipulating schemas - standardized the behavior for table reflection where types can't be located; NullType is substituted instead, warning is raised. - extensions diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 9e9d388c37..d39d8ea086 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -865,11 +865,11 @@ class ANSISchemaGenerator(ANSISchemaBase): def visit_check_constraint(self, constraint): self.append(", \n\t") if constraint.name is not None: - self.append("CONSTRAINT %s " % constraint.name) + self.append("CONSTRAINT %s " % + self.preparer.format_constraint(constraint)) self.append(" CHECK (%s)" % constraint.sqltext) def visit_column_check_constraint(self, constraint): - self.append(" ") self.append(" CHECK (%s)" % constraint.sqltext) def visit_primary_key_constraint(self, constraint): @@ -877,9 +877,9 @@ class ANSISchemaGenerator(ANSISchemaBase): return self.append(", \n\t") if constraint.name is not None: - self.append("CONSTRAINT %s " % constraint.name) + self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) self.append("PRIMARY KEY ") - self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', '))) + self.append("(%s)" % ', '.join([self.preparer.format_column(c) for c in constraint])) def visit_foreign_key_constraint(self, constraint): if constraint.use_alter and self.dialect.supports_alter(): @@ -893,12 +893,14 @@ class ANSISchemaGenerator(ANSISchemaBase): self.execute() def define_foreign_key(self, constraint): + preparer = self.preparer if constraint.name is not None: - self.append("CONSTRAINT %s " % constraint.name) + self.append("CONSTRAINT %s " % + preparer.format_constraint(constraint)) self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % ( - string.join([self.preparer.format_column(f.parent) for f in constraint.elements], ', '), - self.preparer.format_table(list(constraint.elements)[0].column.table), - string.join([self.preparer.format_column(f.column) for f in constraint.elements], ', ') + ', '.join([preparer.format_column(f.parent) for f in constraint.elements]), + preparer.format_table(list(constraint.elements)[0].column.table), + ', '.join([preparer.format_column(f.column) for f in constraint.elements]) )) if constraint.ondelete is not None: self.append(" ON DELETE %s" % constraint.ondelete) @@ -908,20 +910,22 @@ class ANSISchemaGenerator(ANSISchemaBase): def visit_unique_constraint(self, constraint): self.append(", \n\t") if constraint.name is not None: - self.append("CONSTRAINT %s " % constraint.name) - self.append(" UNIQUE ") - self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', '))) + self.append("CONSTRAINT %s " % + self.preparer.format_constraint(constraint)) + self.append(" UNIQUE (%s)" % (', '.join([self.preparer.format_column(c) for c in constraint]))) def visit_column(self, column): pass def visit_index(self, index): - self.append('CREATE ') + preparer = self.preparer + self.append("CREATE ") if index.unique: - self.append('UNIQUE ') - self.append('INDEX %s ON %s (%s)' \ - % (index.name, self.preparer.format_table(index.table), - string.join([self.preparer.format_column(c) for c in index.columns], ', '))) + self.append("UNIQUE ") + self.append("INDEX %s ON %s (%s)" \ + % (preparer.format_index(index), + preparer.format_table(index.table), + string.join([preparer.format_column(c) for c in index.columns], ', '))) self.execute() class ANSISchemaDropper(ANSISchemaBase): @@ -941,11 +945,13 @@ class ANSISchemaDropper(ANSISchemaBase): self.traverse_single(table) def visit_index(self, index): - self.append("\nDROP INDEX " + index.name) + self.append("\nDROP INDEX " + self.preparer.format_index(index)) self.execute() def drop_foreignkey(self, constraint): - self.append("ALTER TABLE %s DROP CONSTRAINT %s" % (self.preparer.format_table(constraint.table), constraint.name)) + self.append("ALTER TABLE %s DROP CONSTRAINT %s" % ( + self.preparer.format_table(constraint.table), + self.preparer.format_constraint(constraint))) self.execute() def visit_table(self, table): @@ -991,7 +997,7 @@ class ANSIIdentifierPreparer(object): return value.replace('"', '""') - def _quote_identifier(self, value): + def quote_identifier(self, value): """Quote an identifier. Subclasses should override this to provide database-dependent @@ -1031,20 +1037,20 @@ class ANSIIdentifierPreparer(object): def __generic_obj_format(self, obj, ident): if getattr(obj, 'quote', False): - return self._quote_identifier(ident) + return self.quote_identifier(ident) if self.dialect.cache_identifiers: case_sens = getattr(obj, 'case_sensitive', None) try: return self.__strings[(ident, case_sens)] except KeyError: if self._requires_quotes(ident, getattr(obj, 'case_sensitive', ident == ident.lower())): - self.__strings[(ident, case_sens)] = self._quote_identifier(ident) + self.__strings[(ident, case_sens)] = self.quote_identifier(ident) else: self.__strings[(ident, case_sens)] = ident return self.__strings[(ident, case_sens)] else: if self._requires_quotes(ident, getattr(obj, 'case_sensitive', ident == ident.lower())): - return self._quote_identifier(ident) + return self.quote_identifier(ident) else: return ident @@ -1063,6 +1069,12 @@ class ANSIIdentifierPreparer(object): def format_savepoint(self, savepoint): return self.__generic_obj_format(savepoint, savepoint) + def format_constraint(self, constraint): + return self.__generic_obj_format(constraint, constraint.name) + + def format_index(self, index): + return self.__generic_obj_format(index, index.name) + def format_table(self, table, use_schema=True, name=None): """Prepare a quoted table and schema name.""" diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index d3f49544dc..6e5616c0bd 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -1069,10 +1069,13 @@ class MySQLDialect(ansisql.ANSIDialect): def is_disconnect(self, e): return isinstance(e, self.dbapi.OperationalError) and e.args[0] in (2006, 2013, 2014, 2045, 2055) - def get_default_schema_name(self): - if not hasattr(self, '_default_schema_name'): - self._default_schema_name = sql.text("select database()", self).scalar() - return self._default_schema_name + def get_default_schema_name(self, connection): + try: + return self._default_schema_name + except AttributeError: + name = self._default_schema_name = \ + connection.execute('SELECT DATABASE()').scalar() + return name def has_table(self, connection, table_name, schema=None): # SHOW TABLE STATUS LIKE and SHOW TABLES LIKE do not function properly @@ -1085,7 +1088,10 @@ class MySQLDialect(ansisql.ANSIDialect): else: st = "DESCRIBE `%s`" % table_name try: - return connection.execute(st).rowcount > 0 + rs = connection.execute(st) + have = rs.rowcount > 0 + rs.close() + return have except exceptions.SQLError, e: if e.orig.args[0] == 1146: return False @@ -1342,11 +1348,15 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): class MySQLSchemaDropper(ansisql.ANSISchemaDropper): def visit_index(self, index): - self.append("\nDROP INDEX " + index.name + " ON " + index.table.name) + self.append("\nDROP INDEX %s ON %s" % + (self.preparer.format_index(index), + self.preparer.format_table(index.table))) self.execute() def drop_foreignkey(self, constraint): - self.append("ALTER TABLE %s DROP FOREIGN KEY %s" % (self.preparer.format_table(constraint.table), constraint.name)) + self.append("ALTER TABLE %s DROP FOREIGN KEY %s" % + (self.preparer.format_table(constraint.table), + self.preparer.format_constraint(constraint))) self.execute() class MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): @@ -1357,8 +1367,7 @@ class MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): return RESERVED_WORDS def _escape_identifier(self, value): - #TODO: determine MySQL's escaping rules - return value + return value.replace('`', '``') def _fold_identifier_case(self, value): #TODO: determine MySQL's case folding rules diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index e7abc1f32b..f544e359ac 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -226,7 +226,8 @@ class SQLiteDialect(ansisql.ANSIDialect): return "oid" def has_table(self, connection, table_name, schema=None): - cursor = connection.execute("PRAGMA table_info(" + table_name + ")", {}) + cursor = connection.execute("PRAGMA table_info(%s)" % + self.identifier_preparer.quote_identifier(table_name), {}) row = cursor.fetchone() # consume remaining rows, to work around: http://www.sqlite.org/cvstrac/tktview?tn=1884 diff --git a/test/engine/reflection.py b/test/engine/reflection.py index 760d9bbf5c..78ffd1fdcf 100644 --- a/test/engine/reflection.py +++ b/test/engine/reflection.py @@ -478,17 +478,38 @@ class ReflectionTest(PersistTest): def testreserved(self): # check a table that uses an SQL reserved name doesn't cause an error meta = MetaData(testbase.db) - table = Table( - 'select', meta, - Column('col1', Integer, primary_key=True) - ) - table.create() + table_a = Table('select', meta, + Column('not', Integer, primary_key=True), + Column('from', String(12), nullable=False), + UniqueConstraint('from', name='when')) + Index('where', table_a.c['from']) + + quoter = meta.bind.dialect.identifier_preparer.quote_identifier + + table_b = Table('false', meta, + Column('create', Integer, primary_key=True), + Column('true', Integer, ForeignKey('select.not')), + CheckConstraint('%s <> 1' % quoter('true'), name='limit')) + + table_c = Table('is', meta, + Column('or', Integer, nullable=False, primary_key=True), + Column('join', Integer, nullable=False, primary_key=True), + PrimaryKeyConstraint('or', 'join', name='to')) + + index_c = Index('else', table_c.c.join) + + #meta.bind.echo = True + meta.create_all() + + index_c.drop() meta2 = MetaData(testbase.db) try: - table2 = Table('select', meta2, autoload=True) + table_a2 = Table('select', meta2, autoload=True) + table_b2 = Table('false', meta2, autoload=True) + table_c2 = Table('is', meta2, autoload=True) finally: - table.drop() + meta.drop_all() class CreateDropTest(PersistTest): def setUpAll(self): @@ -581,6 +602,10 @@ class SchemaTest(PersistTest): @testbase.supported('mysql','postgres') def testcreate(self): + engine = testbase.db + schema = engine.dialect.get_default_schema_name(engine) + #engine.echo = True + if testbase.db.name == 'mysql': schema = testbase.db.url.database else: