self.default = kwargs.pop('default', None)
self.index = kwargs.pop('index', None)
self.unique = kwargs.pop('unique', None)
+ self.onupdate = kwargs.pop('onupdate', None)
if self.index is not None and self.unique is not None:
raise ArgumentError("Column may not define both index and unique")
self._foreign_key = None
return "Column(%s)" % string.join(
[repr(self.name)] + [repr(self.type)] +
[repr(x) for x in [self.foreign_key] if x is not None] +
- ["%s=%s" % (k, repr(getattr(self, k))) for k in ['key', 'primary_key', 'nullable', 'hidden', 'default']]
+ ["%s=%s" % (k, repr(getattr(self, k))) for k in ['key', 'primary_key', 'nullable', 'hidden', 'default', 'onupdate']]
, ',')
def append_item(self, item):
if self.default is not None:
self.default = ColumnDefault(self.default)
self._init_items(self.default)
+ if self.onupdate is not None:
+ self.onupdate = ColumnDefault(self.onupdate, for_update=True)
+ self._init_items(self.onupdate)
self._init_items(*self.args)
self.args = None
class DefaultGenerator(SchemaItem):
"""Base class for column "default" values."""
+ def __init__(self, for_update=False, engine=None):
+ self.for_update = for_update
+ self.engine = engine
def _set_parent(self, column):
self.column = column
- self.column.default = self
+ if self.engine is None:
+ self.engine = column.table.engine
+ if self.for_update:
+ self.column.onupdate = self
+ else:
+ self.column.default = self
def execute(self):
- return self.accept_schema_visitor(self.engine.defaultrunner(self.engine.execute))
+ return self.accept_schema_visitor(self.engine.defaultrunner(self.engine.proxy))
def __repr__(self):
return "DefaultGenerator()"
class PassiveDefault(DefaultGenerator):
"""a default that takes effect on the database side"""
- def __init__(self, arg):
+ def __init__(self, arg, **kwargs):
+ super(PassiveDefault, self).__init__(**kwargs)
self.arg = arg
def accept_schema_visitor(self, visitor):
return visitor.visit_passive_default(self)
class ColumnDefault(DefaultGenerator):
"""A plain default value on a column. this could correspond to a constant,
a callable function, or a SQL clause."""
- def __init__(self, arg):
+ def __init__(self, arg, **kwargs):
+ super(ColumnDefault, self).__init__(**kwargs)
self.arg = arg
def accept_schema_visitor(self, visitor):
"""calls the visit_column_default method on the given visitor."""
class Sequence(DefaultGenerator):
"""represents a sequence, which applies to Oracle and Postgres databases."""
- def __init__(self, name, start = None, increment = None, optional=False, engine=None):
+ def __init__(self, name, start = None, increment = None, optional=False, **kwargs):
+ super(Sequence, self).__init__(**kwargs)
self.name = name
self.start = start
self.increment = increment
self.optional=optional
- self.engine = engine
def __repr__(self):
return "Sequence(%s)" % string.join(
[repr(self.name)] +
def _set_parent(self, column):
super(Sequence, self)._set_parent(column)
column.sequence = self
- if self.engine is None:
- self.engine = column.table.engine
def create(self):
self.engine.create(self)
return self
class DefaultTest(PersistTest):
- def testdefaults(self):
+ def setUpAll(self):
+ global t, f, ts
x = {'x':50}
def mydefault():
x['x'] += 1
Column('col5', deftype, PassiveDefault(def2))
)
t.create()
- try:
- t.insert().execute()
- self.assert_(t.engine.lastrow_has_defaults())
- t.insert().execute()
- t.insert().execute()
+
+ def teststandalonedefaults(self):
+ x = t.c.col1.default.execute()
+ y = t.c.col2.default.execute()
+ z = t.c.col3.default.execute()
+ self.assert_(50 <= x <= 57)
+ self.assert_(y == 'imthedefault')
+ self.assert_(z == 6)
- l = t.select().execute()
- self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts), (52, 'imthedefault', f, ts, ts), (53, 'imthedefault', f, ts, ts)])
- finally:
- t.drop()
+ def testinsertdefaults(self):
+ t.insert().execute()
+ self.assert_(t.engine.lastrow_has_defaults())
+ t.insert().execute()
+ t.insert().execute()
+
+ l = t.select().execute()
+ self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts), (52, 'imthedefault', f, ts, ts), (53, 'imthedefault', f, ts, ts)])
+
+ def tearDownAll(self):
+ t.drop()
class SequenceTest(PersistTest):