improvement to Function so that they can more easily be called standalone without having to throw them into a select().
</&>
</&>
+ <p>Functions also are callable as standalone values:</p>
+ <&|formatting.myt:code &>
+ # call the "now()" function
+ time = func.now(engine=myengine).scalar()
+
+ # call myfunc(1,2,3)
+ myvalue = func.myfunc(1, 2, 3, engine=db).execute()
+
+ # or call them off the engine
+ db.func.now().scalar()
+ </&>
</&>
<&|doclib.myt:item, name="literals", description="Literals" &>
<p>You can drop in a literal value anywhere there isnt a column to attach to via the <span class="codeline">literal</span> keyword:</p>
from sqlalchemy.util import *
import string, re
+ANSI_FUNCS = HashSet([
+'CURRENT_TIME',
+'CURRENT_TIMESTAMP',
+'CURRENT_DATE',
+'LOCAL_TIME',
+'LOCAL_TIMESTAMP',
+'CURRENT_USER',
+'SESSION_USER',
+'USER'
+])
+
+
def engine(**params):
return ANSISQLEngine(**params)
self.select_stack = []
self.typemap = typemap or {}
self.isinsert = False
+ self.isupdate = False
self.bindtemplate = ":%s"
if engine is not None:
self.paramstyle = engine.paramstyle
self.strings[self.statement] = re.sub(match, getnum, self.strings[self.statement])
def get_from_text(self, obj):
- return self.froms[obj]
+ return self.froms.get(obj, None)
def get_str(self, obj):
return self.strings[obj]
else:
return parameters
+ def default_from(self):
+ """called when a SELECT statement has no froms, and no FROM clause is to be appended.
+ gives Oracle a chance to tack on a "FROM DUAL" to the string output. """
+ return ""
+
def visit_label(self, label):
if len(self.select_stack):
self.typemap.setdefault(label.name.lower(), label.obj.type)
self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ')
def visit_function(self, func):
- self.strings[func] = func.name + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")"
+ if len(self.select_stack):
+ self.typemap.setdefault(func.name, func.type)
+ if func.name.upper() in ANSI_FUNCS and not len(func.clauses):
+ self.strings[func] = func.name
+ else:
+ self.strings[func] = func.name + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")"
def visit_compound_select(self, cs):
text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ")
if len(froms):
text += " \nFROM "
text += string.join(froms, ', ')
-
+ else:
+ text += self.default_from()
+
if whereclause is not None:
t = self.get_str(whereclause)
if t:
def visit_insert_column_default(self, column, default):
"""called when visiting an Insert statement, for each column in the table that
- contains a ColumnDefault object."""
+ contains a ColumnDefault object. adds a blank 'placeholder' parameter so the
+ Insert gets compiled with this column's name in its column and VALUES clauses."""
+ self.parameters.setdefault(column.key, None)
+
+ def visit_update_column_default(self, column, default):
+ """called when visiting an Update statement, for each column in the table that
+ contains a ColumnDefault object as an onupdate. adds a blank 'placeholder' parameter so the
+ Update gets compiled with this column's name as one of its SET clauses."""
self.parameters.setdefault(column.key, None)
def visit_insert_sequence(self, column, sequence):
"""called when visiting an Insert statement, for each column in the table that
- contains a Sequence object."""
+ contains a Sequence object. Overridden by compilers that support sequences to place
+ a blank 'placeholder' parameter, so the Insert gets compiled with this column's
+ name in its column and VALUES clauses."""
pass
def visit_insert_column(self, column):
"""called when visiting an Insert statement, for each column in the table
- that is a NULL insert into the table"""
+ that is a NULL insert into the table. Overridden by compilers who disallow
+ NULL columns being set in an Insert where there is a default value on the column
+ (i.e. postgres), to remove the column from the parameter list."""
pass
def visit_insert(self, insert_stmt):
- # set up a call for the defaults and sequences inside the table
+ # scan the table's columns for defaults that have to be pre-set for an INSERT
+ # add these columns to the parameter list via visit_insert_XXX methods
class DefaultVisitor(schema.SchemaVisitor):
def visit_column(s, c):
self.visit_insert_column(c)
self.strings[insert_stmt] = text
def visit_update(self, update_stmt):
+ # scan the table's columns for onupdates that have to be pre-set for an UPDATE
+ # add these columns to the parameter list via visit_update_XXX methods
+ class OnUpdateVisitor(schema.SchemaVisitor):
+ def visit_column_onupdate(s, cd):
+ self.visit_update_column_default(c, cd)
+ vis = OnUpdateVisitor()
+ for c in update_stmt.table.c:
+ if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
+ c.accept_schema_visitor(vis)
+
+ self.isupdate = True
colparams = self._get_colparams(update_stmt)
def create_param(p):
if isinstance(p, sql.BindParamClause):
self._use_ansi = use_ansi
ansisql.ANSICompiler.__init__(self, statement, parameters, engine=engine, **kwargs)
+ def default_from(self):
+ """called when a SELECT statement has no froms, and no FROM clause is to be appended.
+ gives Oracle a chance to tack on a "FROM DUAL" to the string output. """
+ return " FROM DUAL"
+
def visit_join(self, join):
if self._use_ansi:
return ansisql.ANSICompiler.visit_join(self, join)
def get_col_spec(self):
return "BOOLEAN"
-ANSI_FUNCS = util.HashSet([
-'CURRENT_TIME',
-'CURRENT_TIMESTAMP',
-'CURRENT_DATE',
-'LOCAL_TIME',
-'LOCAL_TIMESTAMP',
-'CURRENT_USER',
-'SESSION_USER',
-'USER'
-])
pg2_colspecs = {
sqltypes.Integer : PGInteger,
class PGCompiler(ansisql.ANSICompiler):
- def visit_function(self, func):
- # PG has a bunch of funcs that explicitly need no parenthesis
- if func.name.upper() in ANSI_FUNCS and not len(func.clauses):
- self.strings[func] = func.name
- else:
- super(PGCompiler, self).visit_function(func)
def visit_insert_column(self, column):
# Postgres advises against OID usage and turns it off in 8.1,
else:
return None
+ def get_column_onupdate(self, column):
+ if column.onupdate is not None:
+ return column.onupdate.accept_schema_visitor(self)
+ else:
+ return None
+
def visit_passive_default(self, default):
"""passive defaults by definition return None on the app side,
and are post-fetched to get the DB-side value"""
def exec_default_sql(self, default):
c = sql.select([default.arg], engine=self.engine).compile()
return self.proxy(str(c), c.get_params()).fetchone()[0]
-
+
+ def visit_column_onupdate(self, onupdate):
+ if isinstance(onupdate.arg, sql.ClauseElement):
+ return self.exec_default_sql(onupdate)
+ elif callable(onupdate.arg):
+ return onupdate.arg()
+ else:
+ return onupdate.arg
+
def visit_column_default(self, default):
if isinstance(default.arg, sql.ClauseElement):
return self.exec_default_sql(default)
typeobj = typeobj()
return typeobj
+ def _func(self):
+ class FunctionGateway(object):
+ def __getattr__(s, name):
+ return lambda *c, **kwargs: sql.Function(name, engine=self, *c, **kwargs)
+ return FunctionGateway()
+ func = property(_func)
+
def text(self, text, *args, **kwargs):
"""returns a sql.text() object for performing literal queries."""
return sql.text(text, engine=self, *args, **kwargs)
self.context.tcount = None
def _process_defaults(self, proxy, compiled, parameters, **kwargs):
+ """INSERT and UPDATE statements, when compiled, may have additional columns added to their
+ VALUES and SET lists corresponding to column defaults/onupdates that are present on the
+ Table object (i.e. ColumnDefault, Sequence, PassiveDefault). This method pre-execs those
+ DefaultGenerator objects that require pre-execution and sets their values within the
+ parameter list, and flags the thread-local state about
+ PassiveDefault objects that may require post-fetching the row after it is inserted/updated.
+ This method relies upon logic within the ANSISQLCompiler in its visit_insert and
+ visit_update methods that add the appropriate column clauses to the statement when its
+ being compiled, so that these parameters can be bound to the statement."""
if compiled is None: return
if getattr(compiled, "isinsert", False):
if isinstance(parameters, list):
self.context.last_inserted_ids = None
else:
self.context.last_inserted_ids = last_inserted_ids
-
+ elif getattr(compiled, 'isupdate', False):
+ if isinstance(parameters, list):
+ plist = parameters
+ else:
+ plist = [parameters]
+ drunner = self.defaultrunner(proxy)
+ for param in plist:
+ for c in compiled.statement.table.c:
+ if c.onupdate is not None and (not param.has_key(c.name) or param[c.name] is None):
+ value = drunner.get_column_onupdate(c)
+ if value is not None:
+ param[c.name] = value
+
def lastrow_has_defaults(self):
return self.context.lastrow_has_defaults
then calls visit_column on the visitor."""
if self.default is not None:
self.default.accept_schema_visitor(visitor)
+ if self.onupdate is not None:
+ self.onupdate.accept_schema_visitor(visitor)
if self.foreign_key is not None:
self.foreign_key.accept_schema_visitor(visitor)
visitor.visit_column(self)
self.arg = arg
def accept_schema_visitor(self, visitor):
"""calls the visit_column_default method on the given visitor."""
- return visitor.visit_column_default(self)
+ if self.for_update:
+ return visitor.visit_column_onupdate(self)
+ else:
+ return visitor.visit_column_default(self)
def __repr__(self):
return "ColumnDefault(%s)" % repr(self.arg)
def visit_column_default(self, default):
"""visit a ColumnDefault."""
pass
+ def visit_column_onupdate(self, onupdate):
+ """visit a ColumnDefault with the "for_update" flag set."""
+ pass
def visit_sequence(self, sequence):
"""visit a Sequence."""
pass
def __init__(self, name, *clauses, **kwargs):
self.name = name
self.type = kwargs.get('type', sqltypes.NULLTYPE)
+ self._engine = kwargs.get('engine', None)
+ if self._engine is not None:
+ self.type = self._engine.type_descriptor(self.type)
ClauseList.__init__(self, parens=True, *clauses)
key = property(lambda self:self.name)
def append(self, clause):
else:
clause = BindParamClause(self.name, clause, shortname=self.name, type=None)
self.clauses.append(clause)
+ def _process_from_dict(self, data, asfrom):
+ data.setdefault(self, self)
def copy_container(self):
clauses = [clause.copy_container() for clause in self.clauses]
return Function(self.name, type=self.type, *clauses)
return BindParamClause(self.name, obj, shortname=self.name, type=self.type)
def select(self):
return select([self])
+ def scalar(self):
+ return select([self]).scalar()
+ def execute(self):
+ return select([self]).execute()
def _compare_type(self, obj):
return self.type
import sqlalchemy
db = testbase.db
-
+testbase.echo=False
class DefaultTest(PersistTest):
def setUpAll(self):
- global t, f, ts
+ global t, f, ts, currenttime
x = {'x':50}
def mydefault():
x['x'] += 1
# select "count(1)" from the DB which returns different results
# on different DBs
+ currenttime = db.func.current_date(type=Date);
if is_oracle:
- f = select([func.count(1) + 5], engine=db, from_obj=['DUAL']).scalar()
- ts = select([func.sysdate()], engine=db, from_obj=['DUAL']).scalar()
- def1 = func.sysdate()
+ ts = db.func.sysdate().scalar()
+ f = select([func.count(1) + 5], engine=db).scalar()
+ def1 = currenttime
def2 = text("sysdate")
deftype = Date
elif use_function_defaults:
f = select([func.count(1) + 5], engine=db).scalar()
- def1 = func.current_date()
+ def1 = currenttime
def2 = text("current_date")
deftype = Date
- ts = select([func.current_date()], engine=db).scalar()
+ ts = db.func.current_date().scalar()
else:
f = select([func.count(1) + 5], engine=db).scalar()
def1 = def2 = "3"
Column('col1', Integer, primary_key=True, default=mydefault),
# python literal
- Column('col2', String(20), default="imthedefault"),
+ Column('col2', String(20), default="imthedefault", onupdate="im the update"),
# preexecute expression
- Column('col3', Integer, default=func.count(1) + 5),
+ Column('col3', Integer, default=func.count(1) + 5, onupdate=func.count(1) + 14),
# SQL-side default from sql expression
Column('col4', deftype, PassiveDefault(def1)),
# SQL-side default from literal expression
- Column('col5', deftype, PassiveDefault(def2))
+ Column('col5', deftype, PassiveDefault(def2)),
+
+ # preexecute + update timestamp
+ Column('col6', Date, default=currenttime, onupdate=currenttime)
)
t.create()
- def teststandalonedefaults(self):
+ def tearDownAll(self):
+ t.drop()
+
+ def tearDown(self):
+ t.delete().execute()
+
+ def teststandalone(self):
x = t.c.col1.default.execute()
y = t.c.col2.default.execute()
z = t.c.col3.default.execute()
self.assert_(y == 'imthedefault')
self.assert_(z == 6)
- def testinsertdefaults(self):
+ def testinsert(self):
t.insert().execute()
self.assert_(t.engine.lastrow_has_defaults())
t.insert().execute()
t.insert().execute()
-
- l = t.select().execute()
- self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts), (52, 'imthedefault', f, ts, ts), (53, 'imthedefault', f, ts, ts)])
- def tearDownAll(self):
- t.drop()
+ ctexec = currenttime.scalar()
+ self.echo("Currenttime "+ repr(ctexec))
+ l = t.select().execute()
+ self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec), (52, 'imthedefault', f, ts, ts, ctexec), (53, 'imthedefault', f, ts, ts, ctexec)])
+ def testupdate(self):
+ t.insert().execute()
+ pk = t.engine.last_inserted_ids()[0]
+ t.update(t.c.col1==pk).execute(col4=None, col5=None)
+ ctexec = currenttime.scalar()
+ self.echo("Currenttime "+ repr(ctexec))
+ l = t.select(t.c.col1==pk).execute()
+ l = l.fetchone()
+ self.assert_(l == (pk, 'im the update', 15, None, None, ctexec))
+
class SequenceTest(PersistTest):
def setUpAll(self):