]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
working on sequence quoting support....
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 21 Aug 2006 04:38:51 +0000 (04:38 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 21 Aug 2006 04:38:51 +0000 (04:38 +0000)
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/schema.py

index 53c6db6c47e91cd342ed8b7a7d93e8da0f1f8b5c..f77e855e442364d7b2b221407d61e2d55cbfb87d 100644 (file)
@@ -774,7 +774,15 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor):
             self.__strings[column] = self._quote_identifier(column.name)
         else:
             self.__strings[column] = column.name
-        
+    
+    def visit_sequence(self, sequence):
+        if sequence in self.__visited:
+            return
+        if sequence.quote or self._requires_quotes(sequence.name, sequence.natural_case):
+            self.__strings[sequence] = self._quote_identifier(sequence.name)
+        else:
+            self.__strings[sequence] = sequence.name
+                
     def __analyze_identifiers(self, obj):
         """insure that each object we encounter is analyzed only once for its lifetime."""
         if obj in self.__visited:
@@ -782,7 +790,11 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor):
         if isinstance(obj, schema.SchemaItem):
             obj.accept_schema_visitor(self)
         self.__visited[obj] = True
-         
+    
+    def __prepare_sequence(self, sequence):
+        self.__analyze_identifiers(sequence)
+        return self.__strings.get(sequence, sequence.name)
+             
     def __prepare_table(self, table, use_schema=False):
         self.__analyze_identifiers(table)
         tablename = self.__strings.get(table, (table.name, None))[0]
@@ -798,6 +810,9 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor):
         else:
             return self.__strings.get(column, column.name)
     
+    def format_sequence(self, sequence):
+        return self.__prepare_sequence(sequence)
+        
     def format_table(self, table, use_schema=True):
         """Prepare a quoted table and schema name"""
         return self.__prepare_table(table, use_schema=use_schema)
index 887c593e160e88f122990621cafd2fab0ceca03a..1bcf83409b8e2dd37af997d81d7821f06ed69241 100644 (file)
@@ -530,6 +530,8 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
                 return c.fetchone()[0]
             elif isinstance(column.type, sqltypes.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
                 sch = column.table.schema
+                # TODO: this has to build into the Sequence object so we can get the quoting 
+                # logic from it
                 if sch is not None:
                     exc = "select nextval('%s.%s_%s_seq')" % (sch, column.table.name, column.name)
                 else:
@@ -543,7 +545,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
         
     def visit_sequence(self, seq):
         if not seq.optional:
-            c = self.proxy("select nextval('%s')" % seq.name)
+            c = self.proxy("select nextval('%s')" % seq.name) #TODO: self.dialect.preparer.format_sequence(seq))
             return c.fetchone()[0]
         else:
             return None
index 2bf1627dd076e392c18f47b9a393d0258e3568e3..1693084452169d6ab06cab767037884745a0ac2f 100644 (file)
@@ -169,8 +169,11 @@ class Table(SchemaItem, sql.TableClause):
         self.owner = kwargs.pop('owner', None)
         self.quote = kwargs.pop('quote', False)
         self.quote_schema = kwargs.pop('quote_schema', False)
-        self.natural_case = kwargs.pop('natural_case', True)
-        self.natural_case_schema = kwargs.pop('natural_case_schema', True)
+        default_natural_case = metadata.natural_case
+        if default_natural_case is None:
+            default_natural_case = True
+        self.natural_case = kwargs.pop('natural_case', default_natural_case)
+        self.natural_case_schema = kwargs.pop('natural_case_schema', default_natural_case)
         self.kwargs = kwargs
 
     def _set_primary_key(self, pk):
@@ -403,6 +406,8 @@ class Column(SchemaItem, sql.ColumnClause):
         if getattr(self, 'table', None) is not None:
             raise exceptions.ArgumentError("this Column already has a table!")
         table.append_column(self)
+        if self.table.metadata.natural_case is not None:
+            self.natural_case = self.table.metadata.natural_case
         if self.index or self.unique:
             table.append_index_column(self, index=self.index,
                                       unique=self.unique)
@@ -595,12 +600,14 @@ class ColumnDefault(DefaultGenerator):
         
 class Sequence(DefaultGenerator):
     """represents a sequence, which applies to Oracle and Postgres databases."""
-    def __init__(self, name, start = None, increment = None, optional=False, **kwargs):
+    def __init__(self, name, start = None, increment = None, optional=False, quote=False, natural_case=True, **kwargs):
         super(Sequence, self).__init__(**kwargs)
         self.name = name
         self.start = start
         self.increment = increment
         self.optional=optional
+        self.natural_case = natural_case
+        self.quote = quote
     def __repr__(self):
         return "Sequence(%s)" % string.join(
              [repr(self.name)] +
@@ -609,6 +616,8 @@ class Sequence(DefaultGenerator):
     def _set_parent(self, column):
         super(Sequence, self)._set_parent(column)
         column.sequence = self
+        if column.metadata.natural_case is not None:
+            self.natural_case = column.metadata.natural_case
     def create(self):
        self.engine.create(self)
        return self
@@ -763,10 +772,11 @@ class Index(SchemaItem):
         
 class MetaData(SchemaItem):
     """represents a collection of Tables and their associated schema constructs."""
-    def __init__(self, name=None):
+    def __init__(self, name=None, natural_case=None, **kwargs):
         # a dictionary that stores Table objects keyed off their name (and possibly schema name)
         self.tables = {}
         self.name = name
+        self.natural_case = natural_case
     def is_bound(self):
         return False
     def clear(self):
@@ -850,7 +860,7 @@ class MetaData(SchemaItem):
 class BoundMetaData(MetaData):
     """builds upon MetaData to provide the capability to bind to an Engine implementation."""
     def __init__(self, engine_or_url, name=None, **kwargs):
-        super(BoundMetaData, self).__init__(name)
+        super(BoundMetaData, self).__init__(name, **kwargs)
         if isinstance(engine_or_url, str):
             self._engine = sqlalchemy.create_engine(engine_or_url, **kwargs)
         else:
@@ -861,8 +871,8 @@ class BoundMetaData(MetaData):
 class DynamicMetaData(MetaData):
     """builds upon MetaData to provide the capability to bind to multiple Engine implementations
     on a dynamically alterable, thread-local basis."""
-    def __init__(self, name=None, threadlocal=True):
-        super(DynamicMetaData, self).__init__(name)
+    def __init__(self, name=None, threadlocal=True, **kwargs):
+        super(DynamicMetaData, self).__init__(name, **kwargs)
         if threadlocal:
             self.context = util.ThreadLocal()
         else: