]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
started PassiveDefault, which is a "database-side" default. mapper will go
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 5 Feb 2006 00:19:14 +0000 (00:19 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 5 Feb 2006 00:19:14 +0000 (00:19 +0000)
fetch the most recently inserted row if a table has PassiveDefault's set on it

lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/mapping/mapper.py
lib/sqlalchemy/schema.py

index 2d6adc165bb9005a20453f98d88bd7d8cb033ca2..9122c2afa10827603072098dc5d5ca535d896fcd 100644 (file)
@@ -275,7 +275,9 @@ class PGCompiler(ansisql.ANSICompiler):
 class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, override_pk=False, **kwargs):
         colspec = column.name
-        if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
+        if isinstance(column.default, schema.PassiveDefault):
+            colspec += " DEFAULT " + column.default.text
+        elif column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
             colspec += " SERIAL"
         else:
             colspec += " " + column.type.get_col_spec()
index d99e5eb6ca41b7b6b5f49ca6f89a6fc51af768a4..b00d97de0ba475a9e58c409f033400e13feed1f6 100644 (file)
@@ -135,6 +135,11 @@ class DefaultRunner(schema.SchemaVisitor):
         else:
             return None
 
+    def visit_passive_default(self, default):
+        """passive defaults by definition return None on the app side,
+        and are post-fetched to get the DB-side value"""
+        return None
+        
     def visit_sequence(self, seq):
         """sequences are not supported by default"""
         return None
@@ -452,10 +457,13 @@ class SQLEngine(schema.SchemaEngine):
             else:
                 plist = [parameters]
             drunner = self.defaultrunner(proxy)
+            self.context.lastrow_has_defaults = False
             for param in plist:
                 last_inserted_ids = []
                 need_lastrowid=False
                 for c in compiled.statement.table.c:
+                    if isinstance(c.default, schema.PassiveDefault):
+                        self.context.lastrow_has_defaults = True
                     if not param.has_key(c.key) or param[c.key] is None:
                         newid = drunner.get_column_default(c)
                         if newid is not None:
@@ -471,7 +479,9 @@ class SQLEngine(schema.SchemaEngine):
                 else:
                     self.context.last_inserted_ids = last_inserted_ids
 
-
+    def lastrow_has_defaults(self):
+        return self.context.lastrow_has_defaults
+        
     def pre_exec(self, proxy, compiled, parameters, **kwargs):
         """called by execute_compiled before the compiled statement is executed."""
         pass
index 7e11f5ebe9f8a51b6bb6614c4cb8d921cef01ddc..4516ae7b365b4ef86727d08289029284c4be323e 100644 (file)
@@ -578,6 +578,14 @@ class Mapper(object):
                             if self._getattrbycolumn(obj, col) is None:
                                 self._setattrbycolumn(obj, col, primary_key[i])
                             i+=1
+                    if table.engine.lastrow_has_defaults():
+                        clause = sql.and_()
+                        for p in self.pks_by_table[table]:
+                            clause.clauses.append(p == self._getattrbycolumn(obj, p))
+                        row = table.select(clause).execute().fetchone()
+                        for c in table.c:
+                            if self._getattrbycolumn(obj, col) is None:
+                                self._setattrbycolumn(obj, col, row[c])
                     self.extension.after_insert(self, obj)
                     
     def delete_obj(self, objects, uow):
index a5e6e0777f8b888ba02e1b1f52b282dbf140aad9..de672dc9ee684e357c4a36c18eb4ca2508c6d730 100644 (file)
@@ -417,6 +417,15 @@ class DefaultGenerator(SchemaItem):
         self.column.default = self
     def __repr__(self):
         return "DefaultGenerator()"
+
+class PassiveDefault(DefaultGenerator):
+    """a default that takes effect on the database side"""
+    def __init__(self, text):
+        self.text = text
+    def accept_visitor(self, visitor):
+        return visitor_visit_passive_default(self)
+    def __repr__(self):
+        return "PassiveDefault(%s)" % repr(self.text)
         
 class ColumnDefault(DefaultGenerator):
     """A plain default value on a column.  this could correspond to a constant, 
@@ -477,6 +486,9 @@ class SchemaVisitor(object):
     def visit_index(self, index):
         """visit an Index (not implemented yet)."""
         pass
+    def visit_passive_default(self, default):
+        """visit a passive default"""
+        pass
     def visit_column_default(self, default):
         """visit a ColumnDefault."""
         pass