]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- merged r4868, disallow overly long names from create/drop, from 0.4 branch, [ticket...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 22 Jun 2008 16:56:16 +0000 (16:56 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 22 Jun 2008 16:56:16 +0000 (16:56 +0000)
CHANGES
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/exc.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql/compiler.py
test/sql/labels.py

diff --git a/CHANGES b/CHANGES
index 3edfce6738e92ed73d2fef0d2161ce5a73fd6309..ce1e8ddbc3d1f31afd927e57790626f37d53044a 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -101,11 +101,22 @@ CHANGES
       a transaction is in progress [ticket:976].  This
       flag is always True with a "transactional" 
       (in 0.5 a non-"autocommit") Session.
-    
+
+- schema
+    - create_all(), drop_all(), create(), drop() all raise
+      an error if the table name or schema name contains
+      more characters than that dialect's configured
+      character limit.  Some DBs can handle too-long
+      table names during usage, and SQLA can handle this
+      as well. But various reflection/
+      checkfirst-during-create scenarios fail since we are
+      looking for the name within the DB's catalog tables.
+      [ticket:571]
+
 - postgres
     - Repaired server_side_cursors to properly detect 
       text() clauses.
-      
+
 - mysql
     - Added 'CALL' to the list of SQL keywords which return
       result rows.
index b8578151bfb0279f23dcf5d3bf036d165f4ffb11..dcbf8c76fc4b0fe11b0f92ddc535cc0d53a5ffb8 100644 (file)
@@ -15,7 +15,7 @@ as the base class for their own corresponding classes.
 import re, random
 from sqlalchemy.engine import base
 from sqlalchemy.sql import compiler, expression
-
+from sqlalchemy import exc
 
 AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)',
                                re.I | re.UNICODE)
@@ -70,7 +70,10 @@ class DefaultDialect(base.Dialect):
             typeobj = typeobj()
         return typeobj
 
-
+    def validate_identifier(self, ident):
+        if len(ident) > self.max_identifier_length:
+            raise exc.IdentifierError("Identifier '%s' exceeds maximum length of %d characters" % (ident, self.max_identifier_length))
+        
     def oid_column_name(self, column):
         return None
 
index 71b46ca114590e93809876aa482dca1d62adbec7..e0eb7d88c480f06ff5ae3d4b437f17d6838de438 100644 (file)
@@ -32,6 +32,8 @@ class CircularDependencyError(SQLAlchemyError):
 class CompileError(SQLAlchemyError):
     """Raised when an error occurs during SQL compilation"""
 
+class IdentifierError(SQLAlchemyError):
+    """Raised when a schema name is beyond the max character limit"""
 
 # Moved to orm.exc; compatability definition installed by orm import until 0.6
 ConcurrentModificationError = None
index c5cd0640dc3a7a7a25d033eb97ef8ffb26a2d97e..51c57ca30297a7be21bebdf01957e24308c05a95 100644 (file)
@@ -1391,8 +1391,7 @@ class Query(object):
             statement.append_from(from_clause)
 
             if context.order_by:
-                local_adapter = sql_util.ClauseAdapter(inner)
-                statement.append_order_by(*local_adapter.copy_and_process(context.order_by))
+                statement.append_order_by(*context.adapter.copy_and_process(context.order_by))
 
             statement.append_order_by(*context.eager_order_by)
         else:
@@ -1580,7 +1579,14 @@ class _MapperEntity(_QueryEntity):
         for value in self.mapper._iterate_polymorphic_properties(self._with_polymorphic):
             if query._only_load_props and value.key not in query._only_load_props:
                 continue
-            value.setup(context, self, (self.path_entity,), adapter, only_load_props=query._only_load_props, column_collection=context.primary_columns)
+            value.setup(
+                context, 
+                self, 
+                (self.path_entity,), 
+                adapter, 
+                only_load_props=query._only_load_props, 
+                column_collection=context.primary_columns
+            )
 
     def __str__(self):
         return str(self.mapper)
@@ -1610,7 +1616,11 @@ class _ColumnEntity(_QueryEntity):
         self.column = column
         self.entity_name = None
         self.froms = util.Set()
-        self.entities = util.OrderedSet([elem._annotations['parententity'] for elem in visitors.iterate(column, {}) if 'parententity' in elem._annotations])
+        self.entities = util.OrderedSet([
+            elem._annotations['parententity'] for elem in visitors.iterate(column, {}) 
+            if 'parententity' in elem._annotations
+        ])
+        
         if self.entities:
             self.entity_zero = list(self.entities)[0]
         else:
@@ -1620,11 +1630,11 @@ class _ColumnEntity(_QueryEntity):
         self.selectable = from_obj
         self.froms.add(from_obj)
 
-    def __resolve_expr_against_query_aliases(self, query, expr, context):
+    def _resolve_expr_against_query_aliases(self, query, expr, context):
         return query._adapt_clause(expr, False, True)
 
     def row_processor(self, query, context, custom_rows):
-        column = self.__resolve_expr_against_query_aliases(query, self.column, context)
+        column = self._resolve_expr_against_query_aliases(query, self.column, context)
 
         if context.adapter:
             column = context.adapter.columns[column]
@@ -1635,7 +1645,7 @@ class _ColumnEntity(_QueryEntity):
         return (proc, getattr(column, 'name', None))
 
     def setup_context(self, query, context):
-        column = self.__resolve_expr_against_query_aliases(query, self.column, context)
+        column = self._resolve_expr_against_query_aliases(query, self.column, context)
         context.froms += list(self.froms)
         context.primary_columns.append(column)
 
index 66e9ccd973d9a65b3b2384394ccc115c4715feff..fcb56865b76ff49c4c6f1cf2c8c79d0fa8979995 100644 (file)
@@ -217,7 +217,11 @@ class LoadDeferredColumns(object):
         self.keys = keys
 
     def __getstate__(self):
-        return {'state':self.state, 'key':self.key, 'keys':self.keys}
+        return {
+            'state':self.state, 
+            'key':self.key, 
+            'keys':self.keys
+        }
     
     def __setstate__(self, state):
         self.state = state['state']
@@ -330,7 +334,7 @@ NoLoader.logger = log.class_logger(NoLoader)
 class LazyLoader(AbstractRelationLoader):
     def init(self):
         super(LazyLoader, self).init()
-        (self.__lazywhere, self.__bind_to_col, self._equated_columns) = self.__create_lazy_clause(self.parent_property)
+        (self.__lazywhere, self.__bind_to_col, self._equated_columns) = self._create_lazy_clause(self.parent_property)
         
         self.logger.info("%s lazy loading clause %s" % (self, self.__lazywhere))
 
@@ -352,7 +356,7 @@ class LazyLoader(AbstractRelationLoader):
         if not reverse_direction:
             (criterion, bind_to_col, rev) = (self.__lazywhere, self.__bind_to_col, self._equated_columns)
         else:
-            (criterion, bind_to_col, rev) = LazyLoader.__create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
+            (criterion, bind_to_col, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
 
         def visit_bindparam(bindparam):
             mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent
@@ -371,7 +375,7 @@ class LazyLoader(AbstractRelationLoader):
         if not reverse_direction:
             (criterion, bind_to_col, rev) = (self.__lazywhere, self.__bind_to_col, self._equated_columns)
         else:
-            (criterion, bind_to_col, rev) = LazyLoader.__create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
+            (criterion, bind_to_col, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
 
         def visit_binary(binary):
             mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent
@@ -434,7 +438,7 @@ class LazyLoader(AbstractRelationLoader):
 
             return (new_execute, None)
 
-    def __create_lazy_clause(cls, prop, reverse_direction=False):
+    def _create_lazy_clause(cls, prop, reverse_direction=False):
         binds = {}
         lookup = {}
         equated_columns = {}
@@ -474,7 +478,7 @@ class LazyLoader(AbstractRelationLoader):
         bind_to_col = dict([(binds[col].key, col) for col in binds])
         
         return (lazywhere, bind_to_col, equated_columns)
-    __create_lazy_clause = classmethod(__create_lazy_clause)
+    _create_lazy_clause = classmethod(_create_lazy_clause)
     
 LazyLoader.logger = log.class_logger(LazyLoader)
 
@@ -488,7 +492,12 @@ class LoadLazyAttribute(object):
         self.path = path
         
     def __getstate__(self):
-        return {'state':self.state, 'key':self.key, 'options':self.options, 'path':serialize_path(self.path)}
+        return {
+            'state':self.state, 
+            'key':self.key, 
+            'options':self.options, 
+            'path':serialize_path(self.path)
+        }
 
     def __setstate__(self, state):
         self.state = state['state']
@@ -510,7 +519,11 @@ class LoadLazyAttribute(object):
 
         session = sessionlib._state_session(state)
         if session is None:
-            raise sa_exc.UnboundExecutionError("Parent instance %s is not bound to a Session; lazy load operation of attribute '%s' cannot proceed" % (mapperutil.state_str(state), self.key))
+            raise sa_exc.UnboundExecutionError(
+                "Parent instance %s is not bound to a Session; "
+                "lazy load operation of attribute '%s' cannot proceed" % 
+                (mapperutil.state_str(state), self.key)
+            )
 
         q = session.query(prop.mapper).autoflush(False)._adapt_all_clauses()
         
@@ -547,7 +560,6 @@ class LoadLazyAttribute(object):
                 return result[0]
             else:
                 return None
-        
 
 class EagerLoader(AbstractRelationLoader):
     """Loads related objects inline with a parent query."""
@@ -576,8 +588,7 @@ class EagerLoader(AbstractRelationLoader):
                 context.attributes[("eager_row_processor", path)] = clauses = adapter
                 
         else:
-        
-            clauses = self.__create_eager_join(context, entity, path, adapter, parentmapper)
+            clauses = self._create_eager_join(context, entity, path, adapter, parentmapper)
             if not clauses:
                 return
 
@@ -586,7 +597,7 @@ class EagerLoader(AbstractRelationLoader):
         for value in self.mapper._iterate_polymorphic_properties():
             value.setup(context, entity, path + (self.mapper.base_mapper,), clauses, parentmapper=self.mapper, column_collection=context.secondary_columns)
     
-    def __create_eager_join(self, context, entity, path, adapter, parentmapper):
+    def _create_eager_join(self, context, entity, path, adapter, parentmapper):
         # check for join_depth or basic recursion,
         # if the current path was not explicitly stated as 
         # a desired "loaderstrategy" (i.e. via query.options())
@@ -662,7 +673,7 @@ class EagerLoader(AbstractRelationLoader):
             
         return clauses
         
-    def __create_eager_adapter(self, context, row, adapter, path):
+    def _create_eager_adapter(self, context, row, adapter, path):
         if ("eager_row_processor", path) in context.attributes:
             decorator = context.attributes[("eager_row_processor", path)]
         else:
@@ -682,7 +693,7 @@ class EagerLoader(AbstractRelationLoader):
     def create_row_processor(self, context, path, mapper, row, adapter):
         path = path + (self.key,)
             
-        eager_adapter = self.__create_eager_adapter(context, row, adapter, path)
+        eager_adapter = self._create_eager_adapter(context, row, adapter, path)
         
         if eager_adapter is not False:
             key = self.key
index 8c8374b9a1b54b18710293a72fa6dffbd01baf52..b57fd3b1853a440899cd2b8a769a50d5b24cece0 100644 (file)
@@ -753,8 +753,14 @@ class SchemaGenerator(DDLBase):
     def get_column_specification(self, column, first_pk=False):
         raise NotImplementedError()
 
+    def _can_create(self, table):
+        self.dialect.validate_identifier(table.name)
+        if table.schema:
+            self.dialect.validate_identifier(table.schema)
+        return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema)
+
     def visit_metadata(self, metadata):
-        collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))]
+        collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if self._can_create(t)]
         for table in collection:
             self.traverse_single(table)
         if self.dialect.supports_alter:
@@ -910,13 +916,19 @@ class SchemaDropper(DDLBase):
         self.dialect = dialect
 
     def visit_metadata(self, metadata):
-        collection = [t for t in metadata.table_iterator(reverse=True, tables=self.tables) if (not self.checkfirst or  self.dialect.has_table(self.connection, t.name, schema=t.schema))]
+        collection = [t for t in metadata.table_iterator(reverse=True, tables=self.tables) if self._can_drop(t)]
         if self.dialect.supports_alter:
             for alterable in self.find_alterables(collection):
                 self.drop_foreignkey(alterable)
         for table in collection:
             self.traverse_single(table)
 
+    def _can_drop(self, table):
+        self.dialect.validate_identifier(table.name)
+        if table.schema:
+            self.dialect.validate_identifier(table.schema)
+        return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema)
+
     def visit_index(self, index):
         self.append("\nDROP INDEX " + self.preparer.format_index(index))
         self.execute()
index 78b31adc4aab74586fa82eed83df3191175e5f62..3e025e5e72d77a3acdfd20be674dd883dd1f247b 100644 (file)
@@ -1,5 +1,6 @@
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
+from sqlalchemy import exc as exceptions
 from testlib import *
 from sqlalchemy.engine import default
 
@@ -38,6 +39,14 @@ class LongLabelsTest(TestBase, AssertsCompiledSQL):
         metadata.drop_all()
         testing.db.dialect.max_identifier_length = maxlen
 
+    def test_too_long_name_disallowed(self):
+        m = MetaData(testing.db)
+        t1 = Table("this_name_is_too_long_for_what_were_doing_in_this_test", m, Column('foo', Integer))
+        self.assertRaises(exceptions.IdentifierError, m.create_all)
+        self.assertRaises(exceptions.IdentifierError, m.drop_all)
+        self.assertRaises(exceptions.IdentifierError, t1.create)
+        self.assertRaises(exceptions.IdentifierError, t1.drop)
+        
     def test_result(self):
         table1.insert().execute(**{"this_is_the_primarykey_column":1, "this_is_the_data_column":"data1"})
         table1.insert().execute(**{"this_is_the_primarykey_column":2, "this_is_the_data_column":"data2"})