From: Mike Bayer Date: Sat, 4 Mar 2006 20:23:37 +0000 (+0000) Subject: got column defaults to be executeable X-Git-Tag: rel_0_1_4~43 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c1d0c2dffc0eedfa63de5b90addb70bfd3a81540;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git got column defaults to be executeable --- diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 0f6b659093..7d158cb7e6 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -596,6 +596,17 @@ class SQLEngine(schema.SchemaEngine): def _executemany(self, c, statement, parameters): c.executemany(statement, parameters) self.context.rowcount = c.rowcount + + def proxy(self, statement=None, parameters=None): + executemany = parameters is not None and isinstance(parameters, list) + + if self.positional: + if executemany: + parameters = [p.values() for p in parameters] + else: + parameters = parameters.values() + + return self.execute(statement, parameters) def log(self, msg): """logs a message using this SQLEngine's logger stream.""" diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 17e421f228..57ae7ba5af 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -283,6 +283,7 @@ class Column(sql.ColumnClause, SchemaItem): self.default = kwargs.pop('default', None) self.index = kwargs.pop('index', None) self.unique = kwargs.pop('unique', None) + self.onupdate = kwargs.pop('onupdate', None) if self.index is not None and self.unique is not None: raise ArgumentError("Column may not define both index and unique") self._foreign_key = None @@ -302,7 +303,7 @@ class Column(sql.ColumnClause, SchemaItem): return "Column(%s)" % string.join( [repr(self.name)] + [repr(self.type)] + [repr(x) for x in [self.foreign_key] if x is not None] + - ["%s=%s" % (k, repr(getattr(self, k))) for k in ['key', 'primary_key', 'nullable', 'hidden', 'default']] + ["%s=%s" % (k, repr(getattr(self, k))) for k in ['key', 'primary_key', 'nullable', 'hidden', 'default', 'onupdate']] , ',') def append_item(self, item): @@ -326,6 +327,9 @@ class Column(sql.ColumnClause, SchemaItem): if self.default is not None: self.default = ColumnDefault(self.default) self._init_items(self.default) + if self.onupdate is not None: + self.onupdate = ColumnDefault(self.onupdate, for_update=True) + self._init_items(self.onupdate) self._init_items(*self.args) self.args = None @@ -435,17 +439,26 @@ class ForeignKey(SchemaItem): class DefaultGenerator(SchemaItem): """Base class for column "default" values.""" + def __init__(self, for_update=False, engine=None): + self.for_update = for_update + self.engine = engine def _set_parent(self, column): self.column = column - self.column.default = self + if self.engine is None: + self.engine = column.table.engine + if self.for_update: + self.column.onupdate = self + else: + self.column.default = self def execute(self): - return self.accept_schema_visitor(self.engine.defaultrunner(self.engine.execute)) + return self.accept_schema_visitor(self.engine.defaultrunner(self.engine.proxy)) def __repr__(self): return "DefaultGenerator()" class PassiveDefault(DefaultGenerator): """a default that takes effect on the database side""" - def __init__(self, arg): + def __init__(self, arg, **kwargs): + super(PassiveDefault, self).__init__(**kwargs) self.arg = arg def accept_schema_visitor(self, visitor): return visitor.visit_passive_default(self) @@ -455,7 +468,8 @@ class PassiveDefault(DefaultGenerator): class ColumnDefault(DefaultGenerator): """A plain default value on a column. this could correspond to a constant, a callable function, or a SQL clause.""" - def __init__(self, arg): + def __init__(self, arg, **kwargs): + super(ColumnDefault, self).__init__(**kwargs) self.arg = arg def accept_schema_visitor(self, visitor): """calls the visit_column_default method on the given visitor.""" @@ -465,12 +479,12 @@ class ColumnDefault(DefaultGenerator): class Sequence(DefaultGenerator): """represents a sequence, which applies to Oracle and Postgres databases.""" - def __init__(self, name, start = None, increment = None, optional=False, engine=None): + def __init__(self, name, start = None, increment = None, optional=False, **kwargs): + super(Sequence, self).__init__(**kwargs) self.name = name self.start = start self.increment = increment self.optional=optional - self.engine = engine def __repr__(self): return "Sequence(%s)" % string.join( [repr(self.name)] + @@ -479,8 +493,6 @@ class Sequence(DefaultGenerator): def _set_parent(self, column): super(Sequence, self)._set_parent(column) column.sequence = self - if self.engine is None: - self.engine = column.table.engine def create(self): self.engine.create(self) return self diff --git a/test/defaults.py b/test/defaults.py index fcf852a86a..459b3abfe9 100644 --- a/test/defaults.py +++ b/test/defaults.py @@ -10,7 +10,8 @@ db = testbase.db class DefaultTest(PersistTest): - def testdefaults(self): + def setUpAll(self): + global t, f, ts x = {'x':50} def mydefault(): x['x'] += 1 @@ -56,16 +57,26 @@ class DefaultTest(PersistTest): Column('col5', deftype, PassiveDefault(def2)) ) t.create() - try: - t.insert().execute() - self.assert_(t.engine.lastrow_has_defaults()) - t.insert().execute() - t.insert().execute() + + def teststandalonedefaults(self): + x = t.c.col1.default.execute() + y = t.c.col2.default.execute() + z = t.c.col3.default.execute() + self.assert_(50 <= x <= 57) + self.assert_(y == 'imthedefault') + self.assert_(z == 6) - l = t.select().execute() - self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts), (52, 'imthedefault', f, ts, ts), (53, 'imthedefault', f, ts, ts)]) - finally: - t.drop() + def testinsertdefaults(self): + t.insert().execute() + self.assert_(t.engine.lastrow_has_defaults()) + t.insert().execute() + t.insert().execute() + + l = t.select().execute() + self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts), (52, 'imthedefault', f, ts, ts), (53, 'imthedefault', f, ts, ts)]) + + def tearDownAll(self): + t.drop() class SequenceTest(PersistTest):