]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
got column defaults to be executeable
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 4 Mar 2006 20:23:37 +0000 (20:23 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 4 Mar 2006 20:23:37 +0000 (20:23 +0000)
lib/sqlalchemy/engine.py
lib/sqlalchemy/schema.py
test/defaults.py

index 0f6b659093aab8de3d155fc5493db11867d2bb05..7d158cb7e68e65c1a44c121e7e3b76bebd6b300d 100644 (file)
@@ -596,6 +596,17 @@ class SQLEngine(schema.SchemaEngine):
     def _executemany(self, c, statement, parameters):
         c.executemany(statement, parameters)
         self.context.rowcount = c.rowcount
+
+    def proxy(self, statement=None, parameters=None):
+        executemany = parameters is not None and isinstance(parameters, list)
+
+        if self.positional:
+            if executemany:
+                parameters = [p.values() for p in parameters]
+            else:
+                parameters = parameters.values()
+
+        return self.execute(statement, parameters)
     
     def log(self, msg):
         """logs a message using this SQLEngine's logger stream."""
index 17e421f22805d7fe49d0a919f600e39e5e0058d0..57ae7ba5af87642f7315b16f6a1e14ee3da8406f 100644 (file)
@@ -283,6 +283,7 @@ class Column(sql.ColumnClause, SchemaItem):
         self.default = kwargs.pop('default', None)
         self.index = kwargs.pop('index', None)
         self.unique = kwargs.pop('unique', None)
+        self.onupdate = kwargs.pop('onupdate', None)
         if self.index is not None and self.unique is not None:
             raise ArgumentError("Column may not define both index and unique")
         self._foreign_key = None
@@ -302,7 +303,7 @@ class Column(sql.ColumnClause, SchemaItem):
        return "Column(%s)" % string.join(
         [repr(self.name)] + [repr(self.type)] +
         [repr(x) for x in [self.foreign_key] if x is not None] +
-        ["%s=%s" % (k, repr(getattr(self, k))) for k in ['key', 'primary_key', 'nullable', 'hidden', 'default']]
+        ["%s=%s" % (k, repr(getattr(self, k))) for k in ['key', 'primary_key', 'nullable', 'hidden', 'default', 'onupdate']]
        , ',')
         
     def append_item(self, item):
@@ -326,6 +327,9 @@ class Column(sql.ColumnClause, SchemaItem):
         if self.default is not None:
             self.default = ColumnDefault(self.default)
             self._init_items(self.default)
+        if self.onupdate is not None:
+            self.onupdate = ColumnDefault(self.onupdate, for_update=True)
+            self._init_items(self.onupdate)
         self._init_items(*self.args)
         self.args = None
 
@@ -435,17 +439,26 @@ class ForeignKey(SchemaItem):
 
 class DefaultGenerator(SchemaItem):
     """Base class for column "default" values."""
+    def __init__(self, for_update=False, engine=None):
+        self.for_update = for_update
+        self.engine = engine
     def _set_parent(self, column):
         self.column = column
-        self.column.default = self
+        if self.engine is None:
+            self.engine = column.table.engine
+        if self.for_update:
+            self.column.onupdate = self
+        else:
+            self.column.default = self
     def execute(self):
-        return self.accept_schema_visitor(self.engine.defaultrunner(self.engine.execute))
+        return self.accept_schema_visitor(self.engine.defaultrunner(self.engine.proxy))
     def __repr__(self):
         return "DefaultGenerator()"
 
 class PassiveDefault(DefaultGenerator):
     """a default that takes effect on the database side"""
-    def __init__(self, arg):
+    def __init__(self, arg, **kwargs):
+        super(PassiveDefault, self).__init__(**kwargs)
         self.arg = arg
     def accept_schema_visitor(self, visitor):
         return visitor.visit_passive_default(self)
@@ -455,7 +468,8 @@ class PassiveDefault(DefaultGenerator):
 class ColumnDefault(DefaultGenerator):
     """A plain default value on a column.  this could correspond to a constant, 
     a callable function, or a SQL clause."""
-    def __init__(self, arg):
+    def __init__(self, arg, **kwargs):
+        super(ColumnDefault, self).__init__(**kwargs)
         self.arg = arg
     def accept_schema_visitor(self, visitor):
         """calls the visit_column_default method on the given visitor."""
@@ -465,12 +479,12 @@ class ColumnDefault(DefaultGenerator):
         
 class Sequence(DefaultGenerator):
     """represents a sequence, which applies to Oracle and Postgres databases."""
-    def __init__(self, name, start = None, increment = None, optional=False, engine=None):
+    def __init__(self, name, start = None, increment = None, optional=False, **kwargs):
+        super(Sequence, self).__init__(**kwargs)
         self.name = name
         self.start = start
         self.increment = increment
         self.optional=optional
-        self.engine = engine
     def __repr__(self):
         return "Sequence(%s)" % string.join(
              [repr(self.name)] +
@@ -479,8 +493,6 @@ class Sequence(DefaultGenerator):
     def _set_parent(self, column):
         super(Sequence, self)._set_parent(column)
         column.sequence = self
-        if self.engine is None:
-            self.engine = column.table.engine
     def create(self):
        self.engine.create(self)
        return self
index fcf852a86a77b3a0a0a8b2b5743c136579b86ca1..459b3abfe91fb5571a903f55e6a22a20386c38ba 100644 (file)
@@ -10,7 +10,8 @@ db = testbase.db
 
 class DefaultTest(PersistTest):
 
-    def testdefaults(self):
+    def setUpAll(self):
+        global t, f, ts
         x = {'x':50}
         def mydefault():
             x['x'] += 1
@@ -56,16 +57,26 @@ class DefaultTest(PersistTest):
             Column('col5', deftype, PassiveDefault(def2))
         )
         t.create()
-        try:
-            t.insert().execute()
-            self.assert_(t.engine.lastrow_has_defaults())
-            t.insert().execute()
-            t.insert().execute()
+
+    def teststandalonedefaults(self):
+        x = t.c.col1.default.execute()
+        y = t.c.col2.default.execute()
+        z = t.c.col3.default.execute()
+        self.assert_(50 <= x <= 57)
+        self.assert_(y == 'imthedefault')
+        self.assert_(z == 6)
         
-            l = t.select().execute()
-            self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts), (52, 'imthedefault', f, ts, ts), (53, 'imthedefault', f, ts, ts)])
-        finally:
-            t.drop()
+    def testinsertdefaults(self):
+        t.insert().execute()
+        self.assert_(t.engine.lastrow_has_defaults())
+        t.insert().execute()
+        t.insert().execute()
+    
+        l = t.select().execute()
+        self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts), (52, 'imthedefault', f, ts, ts), (53, 'imthedefault', f, ts, ts)])
+
+    def tearDownAll(self):
+        t.drop()
 
 class SequenceTest(PersistTest):