]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Better quoting of identifiers when manipulating schemas
authorJason Kirtland <jek@discorporate.us>
Thu, 19 Jul 2007 20:44:19 +0000 (20:44 +0000)
committerJason Kirtland <jek@discorporate.us>
Thu, 19 Jul 2007 20:44:19 +0000 (20:44 +0000)
Merged from r2981

CHANGES
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/sqlite.py
test/engine/reflection.py

diff --git a/CHANGES b/CHANGES
index 7f69e89e523c29f53244de823f4a95e1daad4ba1..f61d453da0268d150cef28fd96236ba3ced3c111 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -1,5 +1,6 @@
 0.3.10
 - sql
+    - better quoting of identifiers when manipulating schemas
     - got connection-bound metadata to work with implicit execution
     - foreign key specs can have any chararcter in their identifiers
      [ticket:667]
@@ -7,6 +8,8 @@
       each other, improves ORM lazy load optimization [ticket:664]
 - orm
     - cleanup to connection-bound sessions, SessionTransaction
+- mysql
+    - fixed issue with tables in alternate schemas [ticket:662]
 - postgres
     - fixed max identifier length (63) [ticket:571]
 
index a0f37e1707d1bb4a72e6da1593613263c6dce7c6..9994d528895e9db1f8fe8524ebb3a979ee3ccd7a 100644 (file)
@@ -856,11 +856,11 @@ class ANSISchemaGenerator(ANSISchemaBase):
     def visit_check_constraint(self, constraint):
         self.append(", \n\t")
         if constraint.name is not None:
-            self.append("CONSTRAINT %s " % constraint.name)
+            self.append("CONSTRAINT %s " %
+                        self.preparer.format_constraint(constraint))
         self.append(" CHECK (%s)" % constraint.sqltext)
 
     def visit_column_check_constraint(self, constraint):
-        self.append(" ")
         self.append(" CHECK (%s)" % constraint.sqltext)
 
     def visit_primary_key_constraint(self, constraint):
@@ -868,7 +868,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
             return
         self.append(", \n\t")
         if constraint.name is not None:
-            self.append("CONSTRAINT %s " % constraint.name)
+            self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint))
         self.append("PRIMARY KEY ")
         self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', ')))
 
@@ -884,12 +884,14 @@ class ANSISchemaGenerator(ANSISchemaBase):
         self.execute()
 
     def define_foreign_key(self, constraint):
+        preparer = self.preparer
         if constraint.name is not None:
-            self.append("CONSTRAINT %s " % constraint.name)
+            self.append("CONSTRAINT %s " %
+                        preparer.format_constraint(constraint))
         self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
-            string.join([self.preparer.format_column(f.parent) for f in constraint.elements], ', '),
-            self.preparer.format_table(list(constraint.elements)[0].column.table),
-            string.join([self.preparer.format_column(f.column) for f in constraint.elements], ', ')
+            string.join([preparer.format_column(f.parent) for f in constraint.elements], ', '),
+            preparer.format_table(list(constraint.elements)[0].column.table),
+            string.join([preparer.format_column(f.column) for f in constraint.elements], ', ')
         ))
         if constraint.ondelete is not None:
             self.append(" ON DELETE %s" % constraint.ondelete)
@@ -899,20 +901,22 @@ class ANSISchemaGenerator(ANSISchemaBase):
     def visit_unique_constraint(self, constraint):
         self.append(", \n\t")
         if constraint.name is not None:
-            self.append("CONSTRAINT %s " % constraint.name)
-        self.append(" UNIQUE ")
-        self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', ')))
+            self.append("CONSTRAINT %s " %
+                        self.preparer.format_constraint(constraint))
+        self.append(" UNIQUE (%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', ')))
 
     def visit_column(self, column):
         pass
 
     def visit_index(self, index):
+        preparer = self.preparer        
         self.append('CREATE ')
         if index.unique:
             self.append('UNIQUE ')
         self.append('INDEX %s ON %s (%s)' \
-                    % (index.name, self.preparer.format_table(index.table),
-                       string.join([self.preparer.format_column(c) for c in index.columns], ', ')))
+                    % (preparer.format_index(index),
+                       preparer.format_table(index.table),
+                       string.join([preparer.format_column(c) for c in index.columns], ', ')))
         self.execute()
 
 class ANSISchemaDropper(ANSISchemaBase):
@@ -932,11 +936,13 @@ class ANSISchemaDropper(ANSISchemaBase):
             table.accept_visitor(self)
 
     def visit_index(self, index):
-        self.append("\nDROP INDEX " + index.name)
+        self.append("\nDROP INDEX " + self.preparer.format_index(index))
         self.execute()
 
     def drop_foreignkey(self, constraint):
-        self.append("ALTER TABLE %s DROP CONSTRAINT %s" % (self.preparer.format_table(constraint.table), constraint.name))
+        self.append("ALTER TABLE %s DROP CONSTRAINT %s" % (
+            self.preparer.format_table(constraint.table),
+            self.preparer.format_constraint(constraint)))
         self.execute()
 
     def visit_table(self, table):
@@ -982,7 +988,7 @@ class ANSIIdentifierPreparer(object):
 
         return value.replace('"', '""')
 
-    def _quote_identifier(self, value):
+    def quote_identifier(self, value):
         """Quote an identifier.
 
         Subclasses should override this to provide database-dependent
@@ -1022,20 +1028,20 @@ class ANSIIdentifierPreparer(object):
 
     def __generic_obj_format(self, obj, ident):
         if getattr(obj, 'quote', False):
-            return self._quote_identifier(ident)
+            return self.quote_identifier(ident)
         if self.dialect.cache_identifiers:
             case_sens = getattr(obj, 'case_sensitive', None)
             try:
                 return self.__strings[(ident, case_sens)]
             except KeyError:
                 if self._requires_quotes(ident, getattr(obj, 'case_sensitive', ident == ident.lower())):
-                    self.__strings[(ident, case_sens)] = self._quote_identifier(ident)
+                    self.__strings[(ident, case_sens)] = self.quote_identifier(ident)
                 else:
                     self.__strings[(ident, case_sens)] = ident
                 return self.__strings[(ident, case_sens)]
         else:
             if self._requires_quotes(ident, getattr(obj, 'case_sensitive', ident == ident.lower())):
-                return self._quote_identifier(ident)
+                return self.quote_identifier(ident)
             else:
                 return ident
 
@@ -1054,6 +1060,12 @@ class ANSIIdentifierPreparer(object):
     def format_alias(self, alias):
         return self.__generic_obj_format(alias, alias.name)
 
+    def format_constraint(self, constraint):
+        return self.__generic_obj_format(constraint, constraint.name)
+
+    def format_index(self, index):
+        return self.__generic_obj_format(index, index.name)
+
     def format_table(self, table, use_schema=True, name=None):
         """Prepare a quoted table and schema name."""
 
index 2c42ad3ce8a9af0f555969d4b371a8850d252e2e..f10eae7aeb718be6431fe7681a26c41de9b7a83a 100644 (file)
@@ -1038,10 +1038,13 @@ class MySQLDialect(ansisql.ANSIDialect):
     def is_disconnect(self, e):
         return isinstance(e, self.dbapi.OperationalError) and e.args[0] in (2006, 2013, 2014, 2045, 2055)
 
-    def get_default_schema_name(self):
-        if not hasattr(self, '_default_schema_name'):
-            self._default_schema_name = sql.text("select database()", self).scalar()
-        return self._default_schema_name
+    def get_default_schema_name(self, connection):
+        try:
+            return self._default_schema_name
+        except AttributeError:
+            name = self._default_schema_name = \
+              connection.execute('SELECT DATABASE()').scalar()
+            return name
 
     def has_table(self, connection, table_name, schema=None):
         # SHOW TABLE STATUS LIKE and SHOW TABLES LIKE do not function properly
@@ -1054,7 +1057,10 @@ class MySQLDialect(ansisql.ANSIDialect):
         else:
             st = "DESCRIBE `%s`" % table_name
         try:
-            return connection.execute(st).rowcount > 0
+            rs = connection.execute(st)
+            have = rs.rowcount > 0
+            rs.close()
+            return have
         except exceptions.SQLError, e:
             if e.orig.args[0] == 1146:
                 return False
@@ -1286,11 +1292,15 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
 
 class MySQLSchemaDropper(ansisql.ANSISchemaDropper):
     def visit_index(self, index):
-        self.append("\nDROP INDEX " + index.name + " ON " + index.table.name)
+        self.append("\nDROP INDEX %s ON %s" %
+                    (self.preparer.format_index(index),
+                     self.preparer.format_table(index.table)))
         self.execute()
 
     def drop_foreignkey(self, constraint):
-        self.append("ALTER TABLE %s DROP FOREIGN KEY %s" % (self.preparer.format_table(constraint.table), constraint.name))
+        self.append("ALTER TABLE %s DROP FOREIGN KEY %s" %
+                    (self.preparer.format_table(constraint.table),
+                     self.preparer.format_constraint(constraint)))
         self.execute()
 
 class MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
@@ -1301,8 +1311,7 @@ class MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
         return RESERVED_WORDS
 
     def _escape_identifier(self, value):
-        #TODO: determine MySQL's escaping rules
-        return value
+        return value.replace('`', '``')
 
     def _fold_identifier_case(self, value):
         #TODO: determine MySQL's case folding rules
index 5c4a38b5d7c7df46c506078057efa8bade5e6385..816b1b76a9abfe5204bdc7865dac7788a5e2536d 100644 (file)
@@ -224,7 +224,8 @@ class SQLiteDialect(ansisql.ANSIDialect):
         return "oid"
 
     def has_table(self, connection, table_name, schema=None):
-        cursor = connection.execute("PRAGMA table_info(" + table_name + ")", {})
+        cursor = connection.execute("PRAGMA table_info(%s)" %
+           self.identifier_preparer.quote_identifier(table_name), {})
         row = cursor.fetchone()
 
         # consume remaining rows, to work around: http://www.sqlite.org/cvstrac/tktview?tn=1884
index cb12560ea06941e51fabb2633d227d44e4952f9e..666f3b3266b7f361615e0deaa3a97f395f266351 100644 (file)
@@ -471,6 +471,41 @@ class ReflectionTest(PersistTest):
 
     def testreserved(self):
         # check a table that uses an SQL reserved name doesn't cause an error
+        meta = MetaData(testbase.db)
+        table_a = Table('select', meta, 
+                       Column('not', Integer, primary_key=True),
+                       Column('from', String(12), nullable=False),
+                       UniqueConstraint('from', name='when'))
+        Index('where', table_a.c['from'])
+
+        quoter = meta.bind.dialect.identifier_preparer.quote_identifier
+
+        table_b = Table('false', meta,
+                        Column('create', Integer, primary_key=True),
+                        Column('true', Integer, ForeignKey('select.not')),
+                        CheckConstraint('%s <> 1' % quoter('true'), name='limit'))
+
+        table_c = Table('is', meta,
+                        Column('or', Integer, nullable=False, primary_key=True),
+                        Column('join', Integer, nullable=False, primary_key=True),
+                        PrimaryKeyConstraint('or', 'join', name='to'))
+
+        index_c = Index('else', table_c.c.join)
+
+        #meta.bind.echo = True
+        meta.create_all()
+
+        index_c.drop()
+        
+        meta2 = MetaData(testbase.db)
+        try:
+            table_a2 = Table('select', meta2, autoload=True)
+            table_b2 = Table('false', meta2, autoload=True)
+            table_c2 = Table('is', meta2, autoload=True)
+        finally:
+            meta.drop_all()
+
+
         meta = MetaData(testbase.db)
         table = Table(
             'select', meta, 
@@ -575,7 +610,9 @@ class SchemaTest(PersistTest):
 
     @testbase.unsupported('sqlite')
     def testcreate(self):
-        schema = testbase.db.url.database
+        engine = testbase.db
+        schema = engine.dialect.get_default_schema_name(engine)
+
         metadata = MetaData(testbase.db)
         table1 = Table('table1', metadata, 
             Column('col1', Integer, primary_key=True),