]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Better quoting of identifiers when manipulating schemas.
authorJason Kirtland <jek@discorporate.us>
Thu, 19 Jul 2007 19:24:51 +0000 (19:24 +0000)
committerJason Kirtland <jek@discorporate.us>
Thu, 19 Jul 2007 19:24:51 +0000 (19:24 +0000)
CHANGES
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/sqlite.py
test/engine/reflection.py

diff --git a/CHANGES b/CHANGES
index 8aa7f6dd6963437e6714174ab6ad8d8c9eee89d0..6f72ac3820add35906df220612366e7952b658ea 100644 (file)
--- a/CHANGES
+++ b/CHANGES
   - added "explcit" create/drop/execute support for sequences 
     (i.e. you can pass a "connectable" to each of those methods
     on Sequence)
+  - better quoting of identifiers when manipulating schemas
   - standardized the behavior for table reflection where types can't be located;
     NullType is substituted instead, warning is raised.
 - extensions
index 9e9d388c3782f46ddbaffdbdee51b00576887ad0..d39d8ea0860d8478415a9c846663747eb6202c0f 100644 (file)
@@ -865,11 +865,11 @@ class ANSISchemaGenerator(ANSISchemaBase):
     def visit_check_constraint(self, constraint):
         self.append(", \n\t")
         if constraint.name is not None:
-            self.append("CONSTRAINT %s " % constraint.name)
+            self.append("CONSTRAINT %s " %
+                        self.preparer.format_constraint(constraint))
         self.append(" CHECK (%s)" % constraint.sqltext)
 
     def visit_column_check_constraint(self, constraint):
-        self.append(" ")
         self.append(" CHECK (%s)" % constraint.sqltext)
 
     def visit_primary_key_constraint(self, constraint):
@@ -877,9 +877,9 @@ class ANSISchemaGenerator(ANSISchemaBase):
             return
         self.append(", \n\t")
         if constraint.name is not None:
-            self.append("CONSTRAINT %s " % constraint.name)
+            self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint))
         self.append("PRIMARY KEY ")
-        self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', ')))
+        self.append("(%s)" % ', '.join([self.preparer.format_column(c) for c in constraint]))
 
     def visit_foreign_key_constraint(self, constraint):
         if constraint.use_alter and self.dialect.supports_alter():
@@ -893,12 +893,14 @@ class ANSISchemaGenerator(ANSISchemaBase):
         self.execute()
 
     def define_foreign_key(self, constraint):
+        preparer = self.preparer
         if constraint.name is not None:
-            self.append("CONSTRAINT %s " % constraint.name)
+            self.append("CONSTRAINT %s " %
+                        preparer.format_constraint(constraint))
         self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
-            string.join([self.preparer.format_column(f.parent) for f in constraint.elements], ', '),
-            self.preparer.format_table(list(constraint.elements)[0].column.table),
-            string.join([self.preparer.format_column(f.column) for f in constraint.elements], ', ')
+            ', '.join([preparer.format_column(f.parent) for f in constraint.elements]),
+            preparer.format_table(list(constraint.elements)[0].column.table),
+            ', '.join([preparer.format_column(f.column) for f in constraint.elements])
         ))
         if constraint.ondelete is not None:
             self.append(" ON DELETE %s" % constraint.ondelete)
@@ -908,20 +910,22 @@ class ANSISchemaGenerator(ANSISchemaBase):
     def visit_unique_constraint(self, constraint):
         self.append(", \n\t")
         if constraint.name is not None:
-            self.append("CONSTRAINT %s " % constraint.name)
-        self.append(" UNIQUE ")
-        self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', ')))
+            self.append("CONSTRAINT %s " %
+                        self.preparer.format_constraint(constraint))
+        self.append(" UNIQUE (%s)" % (', '.join([self.preparer.format_column(c) for c in constraint])))
 
     def visit_column(self, column):
         pass
 
     def visit_index(self, index):
-        self.append('CREATE ')
+        preparer = self.preparer
+        self.append("CREATE ")
         if index.unique:
-            self.append('UNIQUE ')
-        self.append('INDEX %s ON %s (%s)' \
-                    % (index.name, self.preparer.format_table(index.table),
-                       string.join([self.preparer.format_column(c) for c in index.columns], ', ')))
+            self.append("UNIQUE ")
+        self.append("INDEX %s ON %s (%s)" \
+                    % (preparer.format_index(index),
+                       preparer.format_table(index.table),
+                       string.join([preparer.format_column(c) for c in index.columns], ', ')))
         self.execute()
 
 class ANSISchemaDropper(ANSISchemaBase):
@@ -941,11 +945,13 @@ class ANSISchemaDropper(ANSISchemaBase):
             self.traverse_single(table)
 
     def visit_index(self, index):
-        self.append("\nDROP INDEX " + index.name)
+        self.append("\nDROP INDEX " + self.preparer.format_index(index))
         self.execute()
 
     def drop_foreignkey(self, constraint):
-        self.append("ALTER TABLE %s DROP CONSTRAINT %s" % (self.preparer.format_table(constraint.table), constraint.name))
+        self.append("ALTER TABLE %s DROP CONSTRAINT %s" % (
+            self.preparer.format_table(constraint.table),
+            self.preparer.format_constraint(constraint)))
         self.execute()
 
     def visit_table(self, table):
@@ -991,7 +997,7 @@ class ANSIIdentifierPreparer(object):
 
         return value.replace('"', '""')
 
-    def _quote_identifier(self, value):
+    def quote_identifier(self, value):
         """Quote an identifier.
 
         Subclasses should override this to provide database-dependent
@@ -1031,20 +1037,20 @@ class ANSIIdentifierPreparer(object):
 
     def __generic_obj_format(self, obj, ident):
         if getattr(obj, 'quote', False):
-            return self._quote_identifier(ident)
+            return self.quote_identifier(ident)
         if self.dialect.cache_identifiers:
             case_sens = getattr(obj, 'case_sensitive', None)
             try:
                 return self.__strings[(ident, case_sens)]
             except KeyError:
                 if self._requires_quotes(ident, getattr(obj, 'case_sensitive', ident == ident.lower())):
-                    self.__strings[(ident, case_sens)] = self._quote_identifier(ident)
+                    self.__strings[(ident, case_sens)] = self.quote_identifier(ident)
                 else:
                     self.__strings[(ident, case_sens)] = ident
                 return self.__strings[(ident, case_sens)]
         else:
             if self._requires_quotes(ident, getattr(obj, 'case_sensitive', ident == ident.lower())):
-                return self._quote_identifier(ident)
+                return self.quote_identifier(ident)
             else:
                 return ident
 
@@ -1063,6 +1069,12 @@ class ANSIIdentifierPreparer(object):
     def format_savepoint(self, savepoint):
         return self.__generic_obj_format(savepoint, savepoint)
 
+    def format_constraint(self, constraint):
+        return self.__generic_obj_format(constraint, constraint.name)
+
+    def format_index(self, index):
+        return self.__generic_obj_format(index, index.name)
+
     def format_table(self, table, use_schema=True, name=None):
         """Prepare a quoted table and schema name."""
 
index d3f49544dcc955aafbfe98e7e8c90519c54c7707..6e5616c0bd699b8654ab0a7b428a3d5030fd620f 100644 (file)
@@ -1069,10 +1069,13 @@ class MySQLDialect(ansisql.ANSIDialect):
     def is_disconnect(self, e):
         return isinstance(e, self.dbapi.OperationalError) and e.args[0] in (2006, 2013, 2014, 2045, 2055)
 
-    def get_default_schema_name(self):
-        if not hasattr(self, '_default_schema_name'):
-            self._default_schema_name = sql.text("select database()", self).scalar()
-        return self._default_schema_name
+    def get_default_schema_name(self, connection):
+        try:
+            return self._default_schema_name
+        except AttributeError:
+            name = self._default_schema_name = \
+              connection.execute('SELECT DATABASE()').scalar()
+            return name
 
     def has_table(self, connection, table_name, schema=None):
         # SHOW TABLE STATUS LIKE and SHOW TABLES LIKE do not function properly
@@ -1085,7 +1088,10 @@ class MySQLDialect(ansisql.ANSIDialect):
         else:
             st = "DESCRIBE `%s`" % table_name
         try:
-            return connection.execute(st).rowcount > 0
+            rs = connection.execute(st)
+            have = rs.rowcount > 0
+            rs.close()
+            return have
         except exceptions.SQLError, e:
             if e.orig.args[0] == 1146:
                 return False
@@ -1342,11 +1348,15 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
 
 class MySQLSchemaDropper(ansisql.ANSISchemaDropper):
     def visit_index(self, index):
-        self.append("\nDROP INDEX " + index.name + " ON " + index.table.name)
+        self.append("\nDROP INDEX %s ON %s" %
+                    (self.preparer.format_index(index),
+                     self.preparer.format_table(index.table)))
         self.execute()
 
     def drop_foreignkey(self, constraint):
-        self.append("ALTER TABLE %s DROP FOREIGN KEY %s" % (self.preparer.format_table(constraint.table), constraint.name))
+        self.append("ALTER TABLE %s DROP FOREIGN KEY %s" %
+                    (self.preparer.format_table(constraint.table),
+                     self.preparer.format_constraint(constraint)))
         self.execute()
 
 class MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
@@ -1357,8 +1367,7 @@ class MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
         return RESERVED_WORDS
 
     def _escape_identifier(self, value):
-        #TODO: determine MySQL's escaping rules
-        return value
+        return value.replace('`', '``')
 
     def _fold_identifier_case(self, value):
         #TODO: determine MySQL's case folding rules
index e7abc1f32b875a1a211f3a7e153429e85680e541..f544e359acd7a81cc18ebf55da7935cf51102d1a 100644 (file)
@@ -226,7 +226,8 @@ class SQLiteDialect(ansisql.ANSIDialect):
         return "oid"
 
     def has_table(self, connection, table_name, schema=None):
-        cursor = connection.execute("PRAGMA table_info(" + table_name + ")", {})
+        cursor = connection.execute("PRAGMA table_info(%s)" %
+           self.identifier_preparer.quote_identifier(table_name), {})
         row = cursor.fetchone()
 
         # consume remaining rows, to work around: http://www.sqlite.org/cvstrac/tktview?tn=1884
index 760d9bbf5ce6ac99288b1ab979be8610e834b18e..78ffd1fdcf248a3482d54a4e9d7d61aee3513ec1 100644 (file)
@@ -478,17 +478,38 @@ class ReflectionTest(PersistTest):
     def testreserved(self):
         # check a table that uses an SQL reserved name doesn't cause an error
         meta = MetaData(testbase.db)
-        table = Table(
-            'select', meta, 
-            Column('col1', Integer, primary_key=True)
-        )
-        table.create()
+        table_a = Table('select', meta, 
+                       Column('not', Integer, primary_key=True),
+                       Column('from', String(12), nullable=False),
+                       UniqueConstraint('from', name='when'))
+        Index('where', table_a.c['from'])
+
+        quoter = meta.bind.dialect.identifier_preparer.quote_identifier
+
+        table_b = Table('false', meta,
+                        Column('create', Integer, primary_key=True),
+                        Column('true', Integer, ForeignKey('select.not')),
+                        CheckConstraint('%s <> 1' % quoter('true'), name='limit'))
+
+        table_c = Table('is', meta,
+                        Column('or', Integer, nullable=False, primary_key=True),
+                        Column('join', Integer, nullable=False, primary_key=True),
+                        PrimaryKeyConstraint('or', 'join', name='to'))
+
+        index_c = Index('else', table_c.c.join)
+
+        #meta.bind.echo = True
+        meta.create_all()
+
+        index_c.drop()
         
         meta2 = MetaData(testbase.db)
         try:
-            table2 = Table('select', meta2, autoload=True)
+            table_a2 = Table('select', meta2, autoload=True)
+            table_b2 = Table('false', meta2, autoload=True)
+            table_c2 = Table('is', meta2, autoload=True)
         finally:
-            table.drop()
+            meta.drop_all()
 
 class CreateDropTest(PersistTest):
     def setUpAll(self):
@@ -581,6 +602,10 @@ class SchemaTest(PersistTest):
     
     @testbase.supported('mysql','postgres')
     def testcreate(self):
+        engine = testbase.db
+        schema = engine.dialect.get_default_schema_name(engine)
+        #engine.echo = True
+
         if testbase.db.name == 'mysql':
             schema = testbase.db.url.database
         else: