From: Mike Bayer Date: Mon, 23 Jul 2007 20:56:27 +0000 (+0000) Subject: ColumnDefault functions pass ExecutionContext to callables which accept a single... X-Git-Tag: rel_0_4_6~41 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0a4d9f9a8f6a01bae64c0216740e9d52548d7cd9;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git ColumnDefault functions pass ExecutionContext to callables which accept a single argument; refactored workings of defaults so that they share the same execution context. --- diff --git a/CHANGES b/CHANGES index f68b80c532..4834e36626 100644 --- a/CHANGES +++ b/CHANGES @@ -163,6 +163,10 @@ will also autoclose the connection if defined for the operation; this allows more efficient usage of connections for successive CRUD operations with less chance of "dangling connections". + - Column defaults and onupdate Python functions (i.e. passed to ColumnDefault) + may take zero or one arguments; the one argument is the ExecutionContext, + from which you can call "context.parameters[someparam]" to access the other + bind parameter values affixed to the statement [ticket:559] - added "explcit" create/drop/execute support for sequences (i.e. you can pass a "connectable" to each of those methods on Sequence) diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 96ca048b11..d8f467358f 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -265,8 +265,8 @@ class PGDialect(ansisql.ANSIDialect): resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts")) return [row[0] for row in resultset] - def defaultrunner(self, connection, **kwargs): - return PGDefaultRunner(connection, **kwargs) + def defaultrunner(self, context, **kwargs): + return PGDefaultRunner(context, **kwargs) def preparer(self): return PGIdentifierPreparer(self) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 796df1a5c1..d2a0d85d7d 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -129,11 +129,11 @@ class Dialect(sql.AbstractDialect): raise NotImplementedError() - def defaultrunner(self, connection, **kwargs): + def defaultrunner(self, execution_context): """Return a [sqlalchemy.schema#SchemaVisitor] instance that can execute defaults. - connection - a [sqlalchemy.engine#Connection] to use for statement execution + execution_context + a [sqlalchemy.engine#ExecutionContext] to use for statement execution """ @@ -514,6 +514,12 @@ class Connection(Connectable): except AttributeError: raise exceptions.InvalidRequestError("This Connection is closed") + def _branch(self): + """return a new Connection which references this Connection's + engine and connection; but does not have close_with_result enabled.""" + + return Connection(self.__engine, self.__connection) + engine = property(lambda s:s.__engine, doc="The Engine with which this Connection is associated.") dialect = property(lambda s:s.__engine.dialect, doc="Dialect used by this Connection.") connection = property(_get_connection, doc="The underlying DBAPI connection managed by this Connection.") @@ -694,7 +700,7 @@ class Connection(Connectable): raise exceptions.InvalidRequestError("Unexecuteable object type: " + str(type(object))) def _execute_default(self, default, multiparams=None, params=None): - return self.__engine.dialect.defaultrunner(self).traverse_single(default) + return self.__engine.dialect.defaultrunner(self.__create_execution_context()).traverse_single(default) def _execute_text(self, statement, multiparams, params): parameters = self.__distill_params(multiparams, params) @@ -1461,10 +1467,13 @@ class DefaultRunner(schema.SchemaVisitor): DefaultRunner to allow database-specific behavior. """ - def __init__(self, connection): - self.connection = connection - self.dialect = connection.dialect + def __init__(self, context): + self.context = context + # branch the connection so it doesnt close after result + self.connection = context.connection._branch() + dialect = property(lambda self:self.context.dialect) + def get_column_default(self, column): if column.default is not None: return self.traverse_single(column.default) @@ -1502,7 +1511,7 @@ class DefaultRunner(schema.SchemaVisitor): if isinstance(onupdate.arg, sql.ClauseElement): return self.exec_default_sql(onupdate) elif callable(onupdate.arg): - return onupdate.arg() + return onupdate.arg(self.context) else: return onupdate.arg @@ -1510,6 +1519,6 @@ class DefaultRunner(schema.SchemaVisitor): if isinstance(default.arg, sql.ClauseElement): return self.exec_default_sql(default) elif callable(default.arg): - return default.arg() + return default.arg(self.context) else: return default.arg diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index dfdc1baaa4..b529b46722 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -115,8 +115,8 @@ class DefaultDialect(base.Dialect): def do_execute(self, cursor, statement, parameters, **kwargs): cursor.execute(statement, parameters) - def defaultrunner(self, connection): - return base.DefaultRunner(connection) + def defaultrunner(self, context): + return base.DefaultRunner(context) def is_disconnect(self, e): return False @@ -172,12 +172,14 @@ class DefaultExecutionContext(base.ExecutionContext): self.compiled_parameters = [compiled.construct_params(m or {}) for m in parameters] if len(self.compiled_parameters) == 1: self.compiled_parameters = self.compiled_parameters[0] - else: + elif statement is not None: self.typemap = self.column_labels = None self.parameters = self.__encode_param_keys(parameters) self.statement = statement - - if not dialect.supports_unicode_statements(): + else: + self.statement = None + + if self.statement is not None and not dialect.supports_unicode_statements(): self.statement = self.statement.encode(self.dialect.encoding) self.cursor = self.create_cursor() @@ -306,7 +308,7 @@ class DefaultExecutionContext(base.ExecutionContext): plist = self.compiled_parameters else: plist = [self.compiled_parameters] - drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection)) + drunner = self.dialect.defaultrunner(self) self._lastrow_has_defaults = False for param in plist: last_inserted_ids = [] @@ -346,7 +348,7 @@ class DefaultExecutionContext(base.ExecutionContext): plist = self.compiled_parameters else: plist = [self.compiled_parameters] - drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection)) + drunner = self.dialect.defaultrunner(self) self._lastrow_has_defaults = False for param in plist: # check the "onupdate" status of each column in the table diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 7a27805374..00b9cff68c 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -19,7 +19,7 @@ objects as well as the visitor interface, so that the schema package from sqlalchemy import sql, types, exceptions,util, databases import sqlalchemy -import re, string +import re, string, inspect __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', 'ForeignKeyConstraint', 'PrimaryKeyConstraint', 'CheckConstraint', 'UniqueConstraint', 'DefaultGenerator', 'Constraint', @@ -802,7 +802,19 @@ class ColumnDefault(DefaultGenerator): def __init__(self, arg, **kwargs): super(ColumnDefault, self).__init__(**kwargs) - self.arg = arg + if callable(arg): + if not inspect.isfunction(arg): + self.arg = lambda ctx: arg() + else: + argspec = inspect.getargspec(arg) + if len(argspec[0]) == 0: + self.arg = lambda ctx: arg() + elif len(argspec[0]) != 1: + raise exceptions.ArgumentError("ColumnDefault Python function takes zero or one positional arguments") + else: + self.arg = arg + else: + self.arg = arg def _visit_name(self): if self.for_update: diff --git a/test/sql/defaults.py b/test/sql/defaults.py index 5cbdc3e3fb..6c200232f2 100644 --- a/test/sql/defaults.py +++ b/test/sql/defaults.py @@ -4,6 +4,7 @@ import sqlalchemy.util as util import sqlalchemy.schema as schema from sqlalchemy.orm import mapper, create_session from testlib import * +import datetime class DefaultTest(PersistTest): @@ -17,6 +18,12 @@ class DefaultTest(PersistTest): x['x'] += 1 return x['x'] + def mydefault_with_ctx(ctx): + return ctx.compiled_parameters['col1'] + 10 + + def myupdate_with_ctx(ctx): + return len(ctx.compiled_parameters['col2']) + use_function_defaults = db.engine.name == 'postgres' or db.engine.name == 'oracle' is_oracle = db.engine.name == 'oracle' @@ -66,7 +73,13 @@ class DefaultTest(PersistTest): Column('col6', Date, default=currenttime, onupdate=currenttime), Column('boolcol1', Boolean, default=True), - Column('boolcol2', Boolean, default=False) + Column('boolcol2', Boolean, default=False), + + # python function which uses ExecutionContext + Column('col7', Integer, default=mydefault_with_ctx, onupdate=myupdate_with_ctx), + + # python builtin + Column('col8', Date, default=datetime.date.today, onupdate=datetime.date.today) ) t.create() @@ -75,7 +88,16 @@ class DefaultTest(PersistTest): def tearDown(self): t.delete().execute() - + + def testargsignature(self): + def mydefault(x, y): + pass + try: + c = ColumnDefault(mydefault) + assert False + except exceptions.ArgumentError, e: + assert str(e) == "ColumnDefault Python function takes zero or one positional arguments", str(e) + def teststandalone(self): c = testbase.db.engine.contextual_connect() x = c.execute(t.c.col1.default) @@ -96,7 +118,8 @@ class DefaultTest(PersistTest): ctexec = currenttime.scalar() print "Currenttime "+ repr(ctexec) l = t.select().execute() - self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False), (52, 'imthedefault', f, ts, ts, ctexec, True, False), (53, 'imthedefault', f, ts, ts, ctexec, True, False)]) + today = datetime.date.today() + self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False, 61, today), (52, 'imthedefault', f, ts, ts, ctexec, True, False, 62, today), (53, 'imthedefault', f, ts, ts, ctexec, True, False, 63, today)]) def testinsertvalues(self): t.insert(values={'col3':50}).execute() @@ -112,7 +135,7 @@ class DefaultTest(PersistTest): print "Currenttime "+ repr(ctexec) l = t.select(t.c.col1==pk).execute() l = l.fetchone() - self.assert_(l == (pk, 'im the update', f2, None, None, ctexec, True, False)) + self.assert_(l == (pk, 'im the update', f2, None, None, ctexec, True, False, 13, datetime.date.today())) # mysql/other db's return 0 or 1 for count(1) self.assert_(14 <= f2 <= 15)