From: Mike Bayer Date: Sun, 5 Feb 2006 00:19:14 +0000 (+0000) Subject: started PassiveDefault, which is a "database-side" default. mapper will go X-Git-Tag: rel_0_1_0~62 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1d20ecbb6d0f0f8cfbb0a54e2a3aaf6cead23ecb;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git started PassiveDefault, which is a "database-side" default. mapper will go fetch the most recently inserted row if a table has PassiveDefault's set on it --- diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 2d6adc165b..9122c2afa1 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -275,7 +275,9 @@ class PGCompiler(ansisql.ANSICompiler): class PGSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, **kwargs): colspec = column.name - if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): + if isinstance(column.default, schema.PassiveDefault): + colspec += " DEFAULT " + column.default.text + elif column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): colspec += " SERIAL" else: colspec += " " + column.type.get_col_spec() diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index d99e5eb6ca..b00d97de0b 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -135,6 +135,11 @@ class DefaultRunner(schema.SchemaVisitor): else: return None + def visit_passive_default(self, default): + """passive defaults by definition return None on the app side, + and are post-fetched to get the DB-side value""" + return None + def visit_sequence(self, seq): """sequences are not supported by default""" return None @@ -452,10 +457,13 @@ class SQLEngine(schema.SchemaEngine): else: plist = [parameters] drunner = self.defaultrunner(proxy) + self.context.lastrow_has_defaults = False for param in plist: last_inserted_ids = [] need_lastrowid=False for c in compiled.statement.table.c: + if isinstance(c.default, schema.PassiveDefault): + self.context.lastrow_has_defaults = True if not param.has_key(c.key) or param[c.key] is None: newid = drunner.get_column_default(c) if newid is not None: @@ -471,7 +479,9 @@ class SQLEngine(schema.SchemaEngine): else: self.context.last_inserted_ids = last_inserted_ids - + def lastrow_has_defaults(self): + return self.context.lastrow_has_defaults + def pre_exec(self, proxy, compiled, parameters, **kwargs): """called by execute_compiled before the compiled statement is executed.""" pass diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py index 7e11f5ebe9..4516ae7b36 100644 --- a/lib/sqlalchemy/mapping/mapper.py +++ b/lib/sqlalchemy/mapping/mapper.py @@ -578,6 +578,14 @@ class Mapper(object): if self._getattrbycolumn(obj, col) is None: self._setattrbycolumn(obj, col, primary_key[i]) i+=1 + if table.engine.lastrow_has_defaults(): + clause = sql.and_() + for p in self.pks_by_table[table]: + clause.clauses.append(p == self._getattrbycolumn(obj, p)) + row = table.select(clause).execute().fetchone() + for c in table.c: + if self._getattrbycolumn(obj, col) is None: + self._setattrbycolumn(obj, col, row[c]) self.extension.after_insert(self, obj) def delete_obj(self, objects, uow): diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index a5e6e0777f..de672dc9ee 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -417,6 +417,15 @@ class DefaultGenerator(SchemaItem): self.column.default = self def __repr__(self): return "DefaultGenerator()" + +class PassiveDefault(DefaultGenerator): + """a default that takes effect on the database side""" + def __init__(self, text): + self.text = text + def accept_visitor(self, visitor): + return visitor_visit_passive_default(self) + def __repr__(self): + return "PassiveDefault(%s)" % repr(self.text) class ColumnDefault(DefaultGenerator): """A plain default value on a column. this could correspond to a constant, @@ -477,6 +486,9 @@ class SchemaVisitor(object): def visit_index(self, index): """visit an Index (not implemented yet).""" pass + def visit_passive_default(self, default): + """visit a passive default""" + pass def visit_column_default(self, default): """visit a ColumnDefault.""" pass