]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
got column onupdate working
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 5 Mar 2006 20:31:44 +0000 (20:31 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 5 Mar 2006 20:31:44 +0000 (20:31 +0000)
improvement to Function so that they can more easily be called standalone without having to throw them into a select().

doc/build/content/sqlconstruction.myt
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
test/defaults.py

index c386705062ff3f6be1bb08f305abb04b02649bc3..065ef2bcc89d57c5cfffd4c010b5dfb27c99b3d4 100644 (file)
@@ -341,6 +341,17 @@ WHERE substr(users.user_name, :substr) = :substr_1
             </&>
 
         </&>
+        <p>Functions also are callable as standalone values:</p>
+        <&|formatting.myt:code &>
+            # call the "now()" function
+            time = func.now(engine=myengine).scalar()
+            
+            # call myfunc(1,2,3)
+            myvalue = func.myfunc(1, 2, 3, engine=db).execute()
+            
+            # or call them off the engine
+            db.func.now().scalar()
+        </&>
         </&>
         <&|doclib.myt:item, name="literals", description="Literals" &>
         <p>You can drop in a literal value anywhere there isnt a column to attach to via the <span class="codeline">literal</span> keyword:</p>
index 7c0002aa58958cb89f94689b95641f8d049ac91a..7b39d5358e2529ba8b0bc1420a10f259471585c0 100644 (file)
@@ -15,6 +15,18 @@ from sqlalchemy.sql import *
 from sqlalchemy.util import *
 import string, re
 
+ANSI_FUNCS = HashSet([
+'CURRENT_TIME',
+'CURRENT_TIMESTAMP',
+'CURRENT_DATE',
+'LOCAL_TIME',
+'LOCAL_TIMESTAMP',
+'CURRENT_USER',
+'SESSION_USER',
+'USER'
+])
+
+
 def engine(**params):
     return ANSISQLEngine(**params)
 
@@ -57,6 +69,7 @@ class ANSICompiler(sql.Compiled):
         self.select_stack = []
         self.typemap = typemap or {}
         self.isinsert = False
+        self.isupdate = False
         self.bindtemplate = ":%s"
         if engine is not None:
             self.paramstyle = engine.paramstyle
@@ -89,7 +102,7 @@ class ANSICompiler(sql.Compiled):
                 self.strings[self.statement] = re.sub(match, getnum, self.strings[self.statement])
 
     def get_from_text(self, obj):
-        return self.froms[obj]
+        return self.froms.get(obj, None)
 
     def get_str(self, obj):
         return self.strings[obj]
@@ -158,6 +171,11 @@ class ANSICompiler(sql.Compiled):
         else:
             return parameters
     
+    def default_from(self):
+        """called when a SELECT statement has no froms, and no FROM clause is to be appended.  
+        gives Oracle a chance to tack on a "FROM DUAL" to the string output. """
+        return ""
+
     def visit_label(self, label):
         if len(self.select_stack):
             self.typemap.setdefault(label.name.lower(), label.obj.type)
@@ -211,7 +229,12 @@ class ANSICompiler(sql.Compiled):
             self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ')
 
     def visit_function(self, func):
-        self.strings[func] = func.name + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")"
+        if len(self.select_stack):
+            self.typemap.setdefault(func.name, func.type)
+        if func.name.upper() in ANSI_FUNCS and not len(func.clauses):
+            self.strings[func] = func.name
+        else:
+            self.strings[func] = func.name + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")"
         
     def visit_compound_select(self, cs):
         text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ")
@@ -325,7 +348,9 @@ class ANSICompiler(sql.Compiled):
         if len(froms):
             text += " \nFROM "
             text += string.join(froms, ', ')
-
+        else:
+            text += self.default_from()
+            
         if whereclause is not None:
             t = self.get_str(whereclause)
             if t:
@@ -384,21 +409,33 @@ class ANSICompiler(sql.Compiled):
 
     def visit_insert_column_default(self, column, default):
         """called when visiting an Insert statement, for each column in the table that
-        contains a ColumnDefault object."""
+        contains a ColumnDefault object.  adds a blank 'placeholder' parameter so the 
+        Insert gets compiled with this column's name in its column and VALUES clauses."""
+        self.parameters.setdefault(column.key, None)
+
+    def visit_update_column_default(self, column, default):
+        """called when visiting an Update statement, for each column in the table that
+        contains a ColumnDefault object as an onupdate. adds a blank 'placeholder' parameter so the 
+        Update gets compiled with this column's name as one of its SET clauses."""
         self.parameters.setdefault(column.key, None)
         
     def visit_insert_sequence(self, column, sequence):
         """called when visiting an Insert statement, for each column in the table that
-        contains a Sequence object."""
+        contains a Sequence object.  Overridden by compilers that support sequences to place
+        a blank 'placeholder' parameter, so the Insert gets compiled with this column's
+        name in its column and VALUES clauses."""
         pass
     
     def visit_insert_column(self, column):
         """called when visiting an Insert statement, for each column in the table
-        that is a NULL insert into the table"""
+        that is a NULL insert into the table.  Overridden by compilers who disallow
+        NULL columns being set in an Insert where there is a default value on the column
+        (i.e. postgres), to remove the column from the parameter list."""
         pass
         
     def visit_insert(self, insert_stmt):
-        # set up a call for the defaults and sequences inside the table
+        # scan the table's columns for defaults that have to be pre-set for an INSERT
+        # add these columns to the parameter list via visit_insert_XXX methods
         class DefaultVisitor(schema.SchemaVisitor):
             def visit_column(s, c):
                 self.visit_insert_column(c)
@@ -424,6 +461,17 @@ class ANSICompiler(sql.Compiled):
         self.strings[insert_stmt] = text
 
     def visit_update(self, update_stmt):
+        # scan the table's columns for onupdates that have to be pre-set for an UPDATE
+        # add these columns to the parameter list via visit_update_XXX methods
+        class OnUpdateVisitor(schema.SchemaVisitor):
+            def visit_column_onupdate(s, cd):
+                self.visit_update_column_default(c, cd)
+        vis = OnUpdateVisitor()
+        for c in update_stmt.table.c:
+            if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
+                c.accept_schema_visitor(vis)
+
+        self.isupdate = True
         colparams = self._get_colparams(update_stmt)
         def create_param(p):
             if isinstance(p, sql.BindParamClause):
index 6f5e98265cadd352fbf4dbaa3a5f73b7bde43671..eab200317b4fde58b86b56cfb840366273c44ccb 100644 (file)
@@ -209,6 +209,11 @@ class OracleCompiler(ansisql.ANSICompiler):
         self._use_ansi = use_ansi
         ansisql.ANSICompiler.__init__(self, statement, parameters, engine=engine, **kwargs)
         
+    def default_from(self):
+        """called when a SELECT statement has no froms, and no FROM clause is to be appended.  
+        gives Oracle a chance to tack on a "FROM DUAL" to the string output. """
+        return " FROM DUAL"
+
     def visit_join(self, join):
         if self._use_ansi:
             return ansisql.ANSICompiler.visit_join(self, join)
index 105fe7a76fe0410052c5cb32acbe2f476397d6a7..db20b636c33dc4fa8e2564f2916307940d5a363e 100644 (file)
@@ -103,16 +103,6 @@ class PGBoolean(sqltypes.Boolean):
     def get_col_spec(self):
         return "BOOLEAN"
 
-ANSI_FUNCS = util.HashSet([
-'CURRENT_TIME',
-'CURRENT_TIMESTAMP',
-'CURRENT_DATE',
-'LOCAL_TIME',
-'LOCAL_TIMESTAMP',
-'CURRENT_USER',
-'SESSION_USER',
-'USER'
-])
 
 pg2_colspecs = {
     sqltypes.Integer : PGInteger,
@@ -283,12 +273,6 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
 
 class PGCompiler(ansisql.ANSICompiler):
 
-    def visit_function(self, func):
-        # PG has a bunch of funcs that explicitly need no parenthesis
-        if func.name.upper() in ANSI_FUNCS and not len(func.clauses):
-            self.strings[func] = func.name
-        else:
-            super(PGCompiler, self).visit_function(func)
         
     def visit_insert_column(self, column):
         # Postgres advises against OID usage and turns it off in 8.1,
index 7d158cb7e68e65c1a44c121e7e3b76bebd6b300d..3703169fa03288c03273847f57f8d85c1c9071c5 100644 (file)
@@ -135,6 +135,12 @@ class DefaultRunner(schema.SchemaVisitor):
         else:
             return None
 
+    def get_column_onupdate(self, column):
+        if column.onupdate is not None:
+            return column.onupdate.accept_schema_visitor(self)
+        else:
+            return None
+        
     def visit_passive_default(self, default):
         """passive defaults by definition return None on the app side,
         and are post-fetched to get the DB-side value"""
@@ -147,7 +153,15 @@ class DefaultRunner(schema.SchemaVisitor):
     def exec_default_sql(self, default):
         c = sql.select([default.arg], engine=self.engine).compile()
         return self.proxy(str(c), c.get_params()).fetchone()[0]
-        
+    
+    def visit_column_onupdate(self, onupdate):
+        if isinstance(onupdate.arg, sql.ClauseElement):
+            return self.exec_default_sql(onupdate)
+        elif callable(onupdate.arg):
+            return onupdate.arg()
+        else:
+            return onupdate.arg
+            
     def visit_column_default(self, default):
         if isinstance(default.arg, sql.ClauseElement):
             return self.exec_default_sql(default)
@@ -245,6 +259,13 @@ class SQLEngine(schema.SchemaEngine):
             typeobj = typeobj()
         return typeobj
 
+    def _func(self):
+        class FunctionGateway(object):
+            def __getattr__(s, name):
+                return lambda *c, **kwargs: sql.Function(name, engine=self, *c, **kwargs)
+        return FunctionGateway()
+    func = property(_func)
+    
     def text(self, text, *args, **kwargs):
         """returns a sql.text() object for performing literal queries."""
         return sql.text(text, engine=self, *args, **kwargs)
@@ -426,6 +447,15 @@ class SQLEngine(schema.SchemaEngine):
                 self.context.tcount = None
 
     def _process_defaults(self, proxy, compiled, parameters, **kwargs):
+        """INSERT and UPDATE statements, when compiled, may have additional columns added to their
+        VALUES and SET lists corresponding to column defaults/onupdates that are present on the 
+        Table object (i.e. ColumnDefault, Sequence, PassiveDefault).  This method pre-execs those
+        DefaultGenerator objects that require pre-execution and sets their values within the 
+        parameter list, and flags the thread-local state about
+        PassiveDefault objects that may require post-fetching the row after it is inserted/updated.  
+        This method relies upon logic within the ANSISQLCompiler in its visit_insert and 
+        visit_update methods that add the appropriate column clauses to the statement when its 
+        being compiled, so that these parameters can be bound to the statement."""
         if compiled is None: return
         if getattr(compiled, "isinsert", False):
             if isinstance(parameters, list):
@@ -454,7 +484,19 @@ class SQLEngine(schema.SchemaEngine):
                     self.context.last_inserted_ids = None
                 else:
                     self.context.last_inserted_ids = last_inserted_ids
-
+        elif getattr(compiled, 'isupdate', False):
+            if isinstance(parameters, list):
+                plist = parameters
+            else:
+                plist = [parameters]
+            drunner = self.defaultrunner(proxy)
+            for param in plist:
+                for c in compiled.statement.table.c:
+                    if c.onupdate is not None and (not param.has_key(c.name) or param[c.name] is None):
+                        value = drunner.get_column_onupdate(c)
+                        if value is not None:
+                            param[c.name] = value
+                        
     def lastrow_has_defaults(self):
         return self.context.lastrow_has_defaults
         
index 57ae7ba5af87642f7315b16f6a1e14ee3da8406f..5cb9f20430a0c675d84a252c59bc08b0c99f49b9 100644 (file)
@@ -364,6 +364,8 @@ class Column(sql.ColumnClause, SchemaItem):
         then calls visit_column on the visitor."""
         if self.default is not None:
             self.default.accept_schema_visitor(visitor)
+        if self.onupdate is not None:
+            self.onupdate.accept_schema_visitor(visitor)
         if self.foreign_key is not None:
             self.foreign_key.accept_schema_visitor(visitor)
         visitor.visit_column(self)
@@ -473,7 +475,10 @@ class ColumnDefault(DefaultGenerator):
         self.arg = arg
     def accept_schema_visitor(self, visitor):
         """calls the visit_column_default method on the given visitor."""
-        return visitor.visit_column_default(self)
+        if self.for_update:
+            return visitor.visit_column_onupdate(self)
+        else:
+            return visitor.visit_column_default(self)
     def __repr__(self):
         return "ColumnDefault(%s)" % repr(self.arg)
         
@@ -599,6 +604,9 @@ class SchemaVisitor(sql.ClauseVisitor):
     def visit_column_default(self, default):
         """visit a ColumnDefault."""
         pass
+    def visit_column_onupdate(self, onupdate):
+        """visit a ColumnDefault with the "for_update" flag set."""
+        pass
     def visit_sequence(self, sequence):
         """visit a Sequence."""
         pass
index f05310e425ccc125ce815fbd111bd4cd0fd8969b..cee328b53a298232cf209426ec9d79d6f9274b2b 100644 (file)
@@ -762,6 +762,9 @@ class Function(ClauseList, ColumnElement):
     def __init__(self, name, *clauses, **kwargs):
         self.name = name
         self.type = kwargs.get('type', sqltypes.NULLTYPE)
+        self._engine = kwargs.get('engine', None)
+        if self._engine is not None:
+            self.type = self._engine.type_descriptor(self.type)
         ClauseList.__init__(self, parens=True, *clauses)
     key = property(lambda self:self.name)
     def append(self, clause):
@@ -771,6 +774,8 @@ class Function(ClauseList, ColumnElement):
             else:
                 clause = BindParamClause(self.name, clause, shortname=self.name, type=None)
         self.clauses.append(clause)
+    def _process_from_dict(self, data, asfrom):
+        data.setdefault(self, self)
     def copy_container(self):
         clauses = [clause.copy_container() for clause in self.clauses]
         return Function(self.name, type=self.type, *clauses)
@@ -782,6 +787,10 @@ class Function(ClauseList, ColumnElement):
         return BindParamClause(self.name, obj, shortname=self.name, type=self.type)
     def select(self):
         return select([self])
+    def scalar(self):
+        return select([self]).scalar()
+    def execute(self):
+        return select([self]).execute()
     def _compare_type(self, obj):
         return self.type
                 
index 459b3abfe91fb5571a903f55e6a22a20386c38ba..c2c8877eb12ebe8cba02d70e9fa0c7ed88785df5 100644 (file)
@@ -7,11 +7,11 @@ from sqlalchemy import *
 import sqlalchemy
 
 db = testbase.db
-
+testbase.echo=False
 class DefaultTest(PersistTest):
 
     def setUpAll(self):
-        global t, f, ts
+        global t, f, ts, currenttime
         x = {'x':50}
         def mydefault():
             x['x'] += 1
@@ -22,18 +22,19 @@ class DefaultTest(PersistTest):
  
         # select "count(1)" from the DB which returns different results
         # on different DBs
+        currenttime = db.func.current_date(type=Date);
         if is_oracle:
-            f = select([func.count(1) + 5], engine=db, from_obj=['DUAL']).scalar()
-            ts = select([func.sysdate()], engine=db, from_obj=['DUAL']).scalar()
-            def1 = func.sysdate()
+            ts = db.func.sysdate().scalar()
+            f = select([func.count(1) + 5], engine=db).scalar()
+            def1 = currenttime
             def2 = text("sysdate")
             deftype = Date
         elif use_function_defaults:
             f = select([func.count(1) + 5], engine=db).scalar()
-            def1 = func.current_date()
+            def1 = currenttime
             def2 = text("current_date")
             deftype = Date
-            ts = select([func.current_date()], engine=db).scalar()
+            ts = db.func.current_date().scalar()
         else:
             f = select([func.count(1) + 5], engine=db).scalar()
             def1 = def2 = "3"
@@ -45,20 +46,29 @@ class DefaultTest(PersistTest):
             Column('col1', Integer, primary_key=True, default=mydefault),
             
             # python literal
-            Column('col2', String(20), default="imthedefault"),
+            Column('col2', String(20), default="imthedefault", onupdate="im the update"),
             
             # preexecute expression
-            Column('col3', Integer, default=func.count(1) + 5),
+            Column('col3', Integer, default=func.count(1) + 5, onupdate=func.count(1) + 14),
             
             # SQL-side default from sql expression
             Column('col4', deftype, PassiveDefault(def1)),
             
             # SQL-side default from literal expression
-            Column('col5', deftype, PassiveDefault(def2))
+            Column('col5', deftype, PassiveDefault(def2)),
+            
+            # preexecute + update timestamp
+            Column('col6', Date, default=currenttime, onupdate=currenttime)
         )
         t.create()
 
-    def teststandalonedefaults(self):
+    def tearDownAll(self):
+        t.drop()
+    
+    def tearDown(self):
+        t.delete().execute()
+        
+    def teststandalone(self):
         x = t.c.col1.default.execute()
         y = t.c.col2.default.execute()
         z = t.c.col3.default.execute()
@@ -66,18 +76,27 @@ class DefaultTest(PersistTest):
         self.assert_(y == 'imthedefault')
         self.assert_(z == 6)
         
-    def testinsertdefaults(self):
+    def testinsert(self):
         t.insert().execute()
         self.assert_(t.engine.lastrow_has_defaults())
         t.insert().execute()
         t.insert().execute()
-    
-        l = t.select().execute()
-        self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts), (52, 'imthedefault', f, ts, ts), (53, 'imthedefault', f, ts, ts)])
 
-    def tearDownAll(self):
-        t.drop()
+        ctexec = currenttime.scalar()
+        self.echo("Currenttime "+ repr(ctexec))
+        l = t.select().execute()
+        self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec), (52, 'imthedefault', f, ts, ts, ctexec), (53, 'imthedefault', f, ts, ts, ctexec)])
 
+    def testupdate(self):
+        t.insert().execute()
+        pk = t.engine.last_inserted_ids()[0]
+        t.update(t.c.col1==pk).execute(col4=None, col5=None)
+        ctexec = currenttime.scalar()
+        self.echo("Currenttime "+ repr(ctexec))
+        l = t.select(t.c.col1==pk).execute()
+        l = l.fetchone()
+        self.assert_(l == (pk, 'im the update', 15, None, None, ctexec))
+        
 class SequenceTest(PersistTest):
 
     def setUpAll(self):