class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, **kwargs):
colspec = column.name
- if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
+ if isinstance(column.default, schema.PassiveDefault):
+ colspec += " DEFAULT " + column.default.text
+ elif column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
colspec += " SERIAL"
else:
colspec += " " + column.type.get_col_spec()
else:
return None
+ def visit_passive_default(self, default):
+ """passive defaults by definition return None on the app side,
+ and are post-fetched to get the DB-side value"""
+ return None
+
def visit_sequence(self, seq):
"""sequences are not supported by default"""
return None
else:
plist = [parameters]
drunner = self.defaultrunner(proxy)
+ self.context.lastrow_has_defaults = False
for param in plist:
last_inserted_ids = []
need_lastrowid=False
for c in compiled.statement.table.c:
+ if isinstance(c.default, schema.PassiveDefault):
+ self.context.lastrow_has_defaults = True
if not param.has_key(c.key) or param[c.key] is None:
newid = drunner.get_column_default(c)
if newid is not None:
else:
self.context.last_inserted_ids = last_inserted_ids
-
+ def lastrow_has_defaults(self):
+ return self.context.lastrow_has_defaults
+
def pre_exec(self, proxy, compiled, parameters, **kwargs):
"""called by execute_compiled before the compiled statement is executed."""
pass
if self._getattrbycolumn(obj, col) is None:
self._setattrbycolumn(obj, col, primary_key[i])
i+=1
+ if table.engine.lastrow_has_defaults():
+ clause = sql.and_()
+ for p in self.pks_by_table[table]:
+ clause.clauses.append(p == self._getattrbycolumn(obj, p))
+ row = table.select(clause).execute().fetchone()
+ for c in table.c:
+ if self._getattrbycolumn(obj, col) is None:
+ self._setattrbycolumn(obj, col, row[c])
self.extension.after_insert(self, obj)
def delete_obj(self, objects, uow):
self.column.default = self
def __repr__(self):
return "DefaultGenerator()"
+
+class PassiveDefault(DefaultGenerator):
+ """a default that takes effect on the database side"""
+ def __init__(self, text):
+ self.text = text
+ def accept_visitor(self, visitor):
+ return visitor_visit_passive_default(self)
+ def __repr__(self):
+ return "PassiveDefault(%s)" % repr(self.text)
class ColumnDefault(DefaultGenerator):
"""A plain default value on a column. this could correspond to a constant,
def visit_index(self, index):
"""visit an Index (not implemented yet)."""
pass
+ def visit_passive_default(self, default):
+ """visit a passive default"""
+ pass
def visit_column_default(self, default):
"""visit a ColumnDefault."""
pass