From 9c4f3c0480f54e08b3aa2800ed76e89f957f8131 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 5 Mar 2006 20:31:44 +0000 Subject: [PATCH] got column onupdate working improvement to Function so that they can more easily be called standalone without having to throw them into a select(). --- doc/build/content/sqlconstruction.myt | 11 +++++ lib/sqlalchemy/ansisql.py | 62 ++++++++++++++++++++++++--- lib/sqlalchemy/databases/oracle.py | 5 +++ lib/sqlalchemy/databases/postgres.py | 16 ------- lib/sqlalchemy/engine.py | 46 +++++++++++++++++++- lib/sqlalchemy/schema.py | 10 ++++- lib/sqlalchemy/sql.py | 9 ++++ test/defaults.py | 53 +++++++++++++++-------- 8 files changed, 169 insertions(+), 43 deletions(-) diff --git a/doc/build/content/sqlconstruction.myt b/doc/build/content/sqlconstruction.myt index c386705062..065ef2bcc8 100644 --- a/doc/build/content/sqlconstruction.myt +++ b/doc/build/content/sqlconstruction.myt @@ -341,6 +341,17 @@ WHERE substr(users.user_name, :substr) = :substr_1 +

Functions also are callable as standalone values:

+ <&|formatting.myt:code &> + # call the "now()" function + time = func.now(engine=myengine).scalar() + + # call myfunc(1,2,3) + myvalue = func.myfunc(1, 2, 3, engine=db).execute() + + # or call them off the engine + db.func.now().scalar() + <&|doclib.myt:item, name="literals", description="Literals" &>

You can drop in a literal value anywhere there isnt a column to attach to via the literal keyword:

diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 7c0002aa58..7b39d5358e 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -15,6 +15,18 @@ from sqlalchemy.sql import * from sqlalchemy.util import * import string, re +ANSI_FUNCS = HashSet([ +'CURRENT_TIME', +'CURRENT_TIMESTAMP', +'CURRENT_DATE', +'LOCAL_TIME', +'LOCAL_TIMESTAMP', +'CURRENT_USER', +'SESSION_USER', +'USER' +]) + + def engine(**params): return ANSISQLEngine(**params) @@ -57,6 +69,7 @@ class ANSICompiler(sql.Compiled): self.select_stack = [] self.typemap = typemap or {} self.isinsert = False + self.isupdate = False self.bindtemplate = ":%s" if engine is not None: self.paramstyle = engine.paramstyle @@ -89,7 +102,7 @@ class ANSICompiler(sql.Compiled): self.strings[self.statement] = re.sub(match, getnum, self.strings[self.statement]) def get_from_text(self, obj): - return self.froms[obj] + return self.froms.get(obj, None) def get_str(self, obj): return self.strings[obj] @@ -158,6 +171,11 @@ class ANSICompiler(sql.Compiled): else: return parameters + def default_from(self): + """called when a SELECT statement has no froms, and no FROM clause is to be appended. + gives Oracle a chance to tack on a "FROM DUAL" to the string output. """ + return "" + def visit_label(self, label): if len(self.select_stack): self.typemap.setdefault(label.name.lower(), label.obj.type) @@ -211,7 +229,12 @@ class ANSICompiler(sql.Compiled): self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ') def visit_function(self, func): - self.strings[func] = func.name + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")" + if len(self.select_stack): + self.typemap.setdefault(func.name, func.type) + if func.name.upper() in ANSI_FUNCS and not len(func.clauses): + self.strings[func] = func.name + else: + self.strings[func] = func.name + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")" def visit_compound_select(self, cs): text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ") @@ -325,7 +348,9 @@ class ANSICompiler(sql.Compiled): if len(froms): text += " \nFROM " text += string.join(froms, ', ') - + else: + text += self.default_from() + if whereclause is not None: t = self.get_str(whereclause) if t: @@ -384,21 +409,33 @@ class ANSICompiler(sql.Compiled): def visit_insert_column_default(self, column, default): """called when visiting an Insert statement, for each column in the table that - contains a ColumnDefault object.""" + contains a ColumnDefault object. adds a blank 'placeholder' parameter so the + Insert gets compiled with this column's name in its column and VALUES clauses.""" + self.parameters.setdefault(column.key, None) + + def visit_update_column_default(self, column, default): + """called when visiting an Update statement, for each column in the table that + contains a ColumnDefault object as an onupdate. adds a blank 'placeholder' parameter so the + Update gets compiled with this column's name as one of its SET clauses.""" self.parameters.setdefault(column.key, None) def visit_insert_sequence(self, column, sequence): """called when visiting an Insert statement, for each column in the table that - contains a Sequence object.""" + contains a Sequence object. Overridden by compilers that support sequences to place + a blank 'placeholder' parameter, so the Insert gets compiled with this column's + name in its column and VALUES clauses.""" pass def visit_insert_column(self, column): """called when visiting an Insert statement, for each column in the table - that is a NULL insert into the table""" + that is a NULL insert into the table. Overridden by compilers who disallow + NULL columns being set in an Insert where there is a default value on the column + (i.e. postgres), to remove the column from the parameter list.""" pass def visit_insert(self, insert_stmt): - # set up a call for the defaults and sequences inside the table + # scan the table's columns for defaults that have to be pre-set for an INSERT + # add these columns to the parameter list via visit_insert_XXX methods class DefaultVisitor(schema.SchemaVisitor): def visit_column(s, c): self.visit_insert_column(c) @@ -424,6 +461,17 @@ class ANSICompiler(sql.Compiled): self.strings[insert_stmt] = text def visit_update(self, update_stmt): + # scan the table's columns for onupdates that have to be pre-set for an UPDATE + # add these columns to the parameter list via visit_update_XXX methods + class OnUpdateVisitor(schema.SchemaVisitor): + def visit_column_onupdate(s, cd): + self.visit_update_column_default(c, cd) + vis = OnUpdateVisitor() + for c in update_stmt.table.c: + if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)): + c.accept_schema_visitor(vis) + + self.isupdate = True colparams = self._get_colparams(update_stmt) def create_param(p): if isinstance(p, sql.BindParamClause): diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 6f5e98265c..eab200317b 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -209,6 +209,11 @@ class OracleCompiler(ansisql.ANSICompiler): self._use_ansi = use_ansi ansisql.ANSICompiler.__init__(self, statement, parameters, engine=engine, **kwargs) + def default_from(self): + """called when a SELECT statement has no froms, and no FROM clause is to be appended. + gives Oracle a chance to tack on a "FROM DUAL" to the string output. """ + return " FROM DUAL" + def visit_join(self, join): if self._use_ansi: return ansisql.ANSICompiler.visit_join(self, join) diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 105fe7a76f..db20b636c3 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -103,16 +103,6 @@ class PGBoolean(sqltypes.Boolean): def get_col_spec(self): return "BOOLEAN" -ANSI_FUNCS = util.HashSet([ -'CURRENT_TIME', -'CURRENT_TIMESTAMP', -'CURRENT_DATE', -'LOCAL_TIME', -'LOCAL_TIMESTAMP', -'CURRENT_USER', -'SESSION_USER', -'USER' -]) pg2_colspecs = { sqltypes.Integer : PGInteger, @@ -283,12 +273,6 @@ class PGSQLEngine(ansisql.ANSISQLEngine): class PGCompiler(ansisql.ANSICompiler): - def visit_function(self, func): - # PG has a bunch of funcs that explicitly need no parenthesis - if func.name.upper() in ANSI_FUNCS and not len(func.clauses): - self.strings[func] = func.name - else: - super(PGCompiler, self).visit_function(func) def visit_insert_column(self, column): # Postgres advises against OID usage and turns it off in 8.1, diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 7d158cb7e6..3703169fa0 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -135,6 +135,12 @@ class DefaultRunner(schema.SchemaVisitor): else: return None + def get_column_onupdate(self, column): + if column.onupdate is not None: + return column.onupdate.accept_schema_visitor(self) + else: + return None + def visit_passive_default(self, default): """passive defaults by definition return None on the app side, and are post-fetched to get the DB-side value""" @@ -147,7 +153,15 @@ class DefaultRunner(schema.SchemaVisitor): def exec_default_sql(self, default): c = sql.select([default.arg], engine=self.engine).compile() return self.proxy(str(c), c.get_params()).fetchone()[0] - + + def visit_column_onupdate(self, onupdate): + if isinstance(onupdate.arg, sql.ClauseElement): + return self.exec_default_sql(onupdate) + elif callable(onupdate.arg): + return onupdate.arg() + else: + return onupdate.arg + def visit_column_default(self, default): if isinstance(default.arg, sql.ClauseElement): return self.exec_default_sql(default) @@ -245,6 +259,13 @@ class SQLEngine(schema.SchemaEngine): typeobj = typeobj() return typeobj + def _func(self): + class FunctionGateway(object): + def __getattr__(s, name): + return lambda *c, **kwargs: sql.Function(name, engine=self, *c, **kwargs) + return FunctionGateway() + func = property(_func) + def text(self, text, *args, **kwargs): """returns a sql.text() object for performing literal queries.""" return sql.text(text, engine=self, *args, **kwargs) @@ -426,6 +447,15 @@ class SQLEngine(schema.SchemaEngine): self.context.tcount = None def _process_defaults(self, proxy, compiled, parameters, **kwargs): + """INSERT and UPDATE statements, when compiled, may have additional columns added to their + VALUES and SET lists corresponding to column defaults/onupdates that are present on the + Table object (i.e. ColumnDefault, Sequence, PassiveDefault). This method pre-execs those + DefaultGenerator objects that require pre-execution and sets their values within the + parameter list, and flags the thread-local state about + PassiveDefault objects that may require post-fetching the row after it is inserted/updated. + This method relies upon logic within the ANSISQLCompiler in its visit_insert and + visit_update methods that add the appropriate column clauses to the statement when its + being compiled, so that these parameters can be bound to the statement.""" if compiled is None: return if getattr(compiled, "isinsert", False): if isinstance(parameters, list): @@ -454,7 +484,19 @@ class SQLEngine(schema.SchemaEngine): self.context.last_inserted_ids = None else: self.context.last_inserted_ids = last_inserted_ids - + elif getattr(compiled, 'isupdate', False): + if isinstance(parameters, list): + plist = parameters + else: + plist = [parameters] + drunner = self.defaultrunner(proxy) + for param in plist: + for c in compiled.statement.table.c: + if c.onupdate is not None and (not param.has_key(c.name) or param[c.name] is None): + value = drunner.get_column_onupdate(c) + if value is not None: + param[c.name] = value + def lastrow_has_defaults(self): return self.context.lastrow_has_defaults diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 57ae7ba5af..5cb9f20430 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -364,6 +364,8 @@ class Column(sql.ColumnClause, SchemaItem): then calls visit_column on the visitor.""" if self.default is not None: self.default.accept_schema_visitor(visitor) + if self.onupdate is not None: + self.onupdate.accept_schema_visitor(visitor) if self.foreign_key is not None: self.foreign_key.accept_schema_visitor(visitor) visitor.visit_column(self) @@ -473,7 +475,10 @@ class ColumnDefault(DefaultGenerator): self.arg = arg def accept_schema_visitor(self, visitor): """calls the visit_column_default method on the given visitor.""" - return visitor.visit_column_default(self) + if self.for_update: + return visitor.visit_column_onupdate(self) + else: + return visitor.visit_column_default(self) def __repr__(self): return "ColumnDefault(%s)" % repr(self.arg) @@ -599,6 +604,9 @@ class SchemaVisitor(sql.ClauseVisitor): def visit_column_default(self, default): """visit a ColumnDefault.""" pass + def visit_column_onupdate(self, onupdate): + """visit a ColumnDefault with the "for_update" flag set.""" + pass def visit_sequence(self, sequence): """visit a Sequence.""" pass diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index f05310e425..cee328b53a 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -762,6 +762,9 @@ class Function(ClauseList, ColumnElement): def __init__(self, name, *clauses, **kwargs): self.name = name self.type = kwargs.get('type', sqltypes.NULLTYPE) + self._engine = kwargs.get('engine', None) + if self._engine is not None: + self.type = self._engine.type_descriptor(self.type) ClauseList.__init__(self, parens=True, *clauses) key = property(lambda self:self.name) def append(self, clause): @@ -771,6 +774,8 @@ class Function(ClauseList, ColumnElement): else: clause = BindParamClause(self.name, clause, shortname=self.name, type=None) self.clauses.append(clause) + def _process_from_dict(self, data, asfrom): + data.setdefault(self, self) def copy_container(self): clauses = [clause.copy_container() for clause in self.clauses] return Function(self.name, type=self.type, *clauses) @@ -782,6 +787,10 @@ class Function(ClauseList, ColumnElement): return BindParamClause(self.name, obj, shortname=self.name, type=self.type) def select(self): return select([self]) + def scalar(self): + return select([self]).scalar() + def execute(self): + return select([self]).execute() def _compare_type(self, obj): return self.type diff --git a/test/defaults.py b/test/defaults.py index 459b3abfe9..c2c8877eb1 100644 --- a/test/defaults.py +++ b/test/defaults.py @@ -7,11 +7,11 @@ from sqlalchemy import * import sqlalchemy db = testbase.db - +testbase.echo=False class DefaultTest(PersistTest): def setUpAll(self): - global t, f, ts + global t, f, ts, currenttime x = {'x':50} def mydefault(): x['x'] += 1 @@ -22,18 +22,19 @@ class DefaultTest(PersistTest): # select "count(1)" from the DB which returns different results # on different DBs + currenttime = db.func.current_date(type=Date); if is_oracle: - f = select([func.count(1) + 5], engine=db, from_obj=['DUAL']).scalar() - ts = select([func.sysdate()], engine=db, from_obj=['DUAL']).scalar() - def1 = func.sysdate() + ts = db.func.sysdate().scalar() + f = select([func.count(1) + 5], engine=db).scalar() + def1 = currenttime def2 = text("sysdate") deftype = Date elif use_function_defaults: f = select([func.count(1) + 5], engine=db).scalar() - def1 = func.current_date() + def1 = currenttime def2 = text("current_date") deftype = Date - ts = select([func.current_date()], engine=db).scalar() + ts = db.func.current_date().scalar() else: f = select([func.count(1) + 5], engine=db).scalar() def1 = def2 = "3" @@ -45,20 +46,29 @@ class DefaultTest(PersistTest): Column('col1', Integer, primary_key=True, default=mydefault), # python literal - Column('col2', String(20), default="imthedefault"), + Column('col2', String(20), default="imthedefault", onupdate="im the update"), # preexecute expression - Column('col3', Integer, default=func.count(1) + 5), + Column('col3', Integer, default=func.count(1) + 5, onupdate=func.count(1) + 14), # SQL-side default from sql expression Column('col4', deftype, PassiveDefault(def1)), # SQL-side default from literal expression - Column('col5', deftype, PassiveDefault(def2)) + Column('col5', deftype, PassiveDefault(def2)), + + # preexecute + update timestamp + Column('col6', Date, default=currenttime, onupdate=currenttime) ) t.create() - def teststandalonedefaults(self): + def tearDownAll(self): + t.drop() + + def tearDown(self): + t.delete().execute() + + def teststandalone(self): x = t.c.col1.default.execute() y = t.c.col2.default.execute() z = t.c.col3.default.execute() @@ -66,18 +76,27 @@ class DefaultTest(PersistTest): self.assert_(y == 'imthedefault') self.assert_(z == 6) - def testinsertdefaults(self): + def testinsert(self): t.insert().execute() self.assert_(t.engine.lastrow_has_defaults()) t.insert().execute() t.insert().execute() - - l = t.select().execute() - self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts), (52, 'imthedefault', f, ts, ts), (53, 'imthedefault', f, ts, ts)]) - def tearDownAll(self): - t.drop() + ctexec = currenttime.scalar() + self.echo("Currenttime "+ repr(ctexec)) + l = t.select().execute() + self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec), (52, 'imthedefault', f, ts, ts, ctexec), (53, 'imthedefault', f, ts, ts, ctexec)]) + def testupdate(self): + t.insert().execute() + pk = t.engine.last_inserted_ids()[0] + t.update(t.c.col1==pk).execute(col4=None, col5=None) + ctexec = currenttime.scalar() + self.echo("Currenttime "+ repr(ctexec)) + l = t.select(t.c.col1==pk).execute() + l = l.fetchone() + self.assert_(l == (pk, 'im the update', 15, None, None, ctexec)) + class SequenceTest(PersistTest): def setUpAll(self): -- 2.47.2