]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
rethinking sequences model to allow any default values
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 15 Dec 2005 04:25:59 +0000 (04:25 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 15 Dec 2005 04:25:59 +0000 (04:25 +0000)
lib/sqlalchemy/engine.py
lib/sqlalchemy/mapping/mapper.py
lib/sqlalchemy/schema.py

index cad60b969779ebc1870080d569a63fc654c73ba2..977b4c427e3f1d431226acf743f70411aff7d784 100644 (file)
@@ -241,6 +241,41 @@ class SQLEngine(schema.SchemaEngine):
                 self.do_commit(self.context.transaction)
                 self.context.transaction = None
                 self.context.tcount = None
+
+
+    def _process_sequences(self, connection, cursor, statement, parameters, many = False, echo = None, **kwargs):
+        if compiled is None: return
+        if getattr(compiled, "isinsert", False):
+            if isinstance(parameters, list):
+                plist = parameters
+            else:
+                plist = [parameters]
+            # inserts are usually one at a time.  but if we got a list of parameters,
+            # it will calculate last_inserted_ids for just the last row in the list. 
+            # TODO: why not make last_inserted_ids a 2D array since we have to explicitly sequence
+            # it or post-select anyway   
+            for param in plist:
+                last_inserted_ids = []
+                need_lastrowid=False
+                for c in compiled.statement.table.c:
+                    if not param.has_key(c.key) or param[c.key] is None:
+                        if c.sequence is not None:
+                            newid = self.exec_sequence(c.sequence)
+                        else:
+                            newid = None
+                            
+                        if newid is not None:
+                            param[c.key] = newid
+                            if c.primary_key:
+                                last_inserted_ids.append(param[c.key])
+                        elif c.primary_key:
+                            need_lastrowid = True
+                    elif c.primary_key:
+                        last_inserted_ids.append(param[c.key])
+                if need_lastrowid:
+                    self.context.last_inserted_ids = None
+                else:
+                    self.context.last_inserted_ids = last_inserted_ids
             
     def pre_exec(self, connection, cursor, statement, parameters, many = False, echo = None, **kwargs):
         pass
@@ -287,7 +322,8 @@ class SQLEngine(schema.SchemaEngine):
 
         try:
             self.pre_exec(connection, c, statement, parameters, echo = echo, **kwargs)
-
+            #self._process_sequences(connection, c, statement, parameters, echo = echo, **kwargs)
+            
             if echo is True or self.echo is not False:
                 self.log(statement)
                 self.log(repr(parameters))
index 5fdf0e35b0f50e2b25598a1f79e4329e58fb072a..cfa373e0aea43493b8258cc7a2b4df13ceaeb456 100644 (file)
@@ -459,6 +459,7 @@ class Mapper(object):
                 if hasattr(obj, "_instance_key"):
                     update.append(params)
                 else:
+                    self.extension.before_insert(self, obj)
                     insert.append((obj, params))
                 uow.register_saved_object(obj)
             if len(update):
@@ -476,6 +477,11 @@ class Mapper(object):
                 for rec in insert:
                     (obj, params) = rec
                     statement.execute(**params)
+                    # TODO: the engine is going to now do defaults for non-primarykey columns as well.
+                    # have the engine store a dictionary of all column/generated values and set them
+                    # all up.  also might want to have the last_inserted_ids that does a select actually
+                    # go and get everything that was generated by DB-level defaults, not just primary key 
+                    # columns.
                     primary_key = table.engine.last_inserted_ids()
                     if primary_key is not None:
                         i = 0
@@ -676,8 +682,8 @@ class ExtensionOption(MapperOption):
     def __init__(self, ext):
         self.ext = ext
     def process(self, mapper):
-        ext.next = mapper.extension
-        mapper.extension = ext
+        self.ext.next = mapper.extension
+        mapper.extension = self.ext
 
 class MapperExtension(object):
     def __init__(self):
@@ -692,6 +698,9 @@ class MapperExtension(object):
             return True
         else:
             return self.next.append_result(mapper, row, imap, result, instance, isnew, populate_existing)
+    def before_insert(self, mapper, instance):
+        if self.next is not None:
+            self.next.before_insert(mapper, instance)
     def after_insert(self, mapper, instance):
         if self.next is not None:
             self.next.after_insert(mapper, instance)
index 789bb6ed8c03054cc14e7cc39de919e29787e95d..60c42c25ad8e577bccc5294b17a42a4562eb7bb7 100644 (file)
@@ -288,8 +288,9 @@ class ForeignKey(SchemaItem):
         
 class Sequence(SchemaItem):
     """represents a sequence, which applies to Oracle and Postgres databases."""
-    def __init__(self, name, start = None, increment = None, optional=False):
+    def __init__(self, name, func = None, start = None, increment = None, optional=False):
         self.name = name
+        self.func = func
         self.start = start
         self.increment = increment
         self.optional=optional