refactored workings of defaults so that they share the same execution context.
will also autoclose the connection if defined for the operation; this
allows more efficient usage of connections for successive CRUD operations
with less chance of "dangling connections".
+ - Column defaults and onupdate Python functions (i.e. passed to ColumnDefault)
+ may take zero or one arguments; the one argument is the ExecutionContext,
+ from which you can call "context.parameters[someparam]" to access the other
+ bind parameter values affixed to the statement [ticket:559]
- added "explcit" create/drop/execute support for sequences
(i.e. you can pass a "connectable" to each of those methods
on Sequence)
resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts"))
return [row[0] for row in resultset]
- def defaultrunner(self, connection, **kwargs):
- return PGDefaultRunner(connection, **kwargs)
+ def defaultrunner(self, context, **kwargs):
+ return PGDefaultRunner(context, **kwargs)
def preparer(self):
return PGIdentifierPreparer(self)
raise NotImplementedError()
- def defaultrunner(self, connection, **kwargs):
+ def defaultrunner(self, execution_context):
"""Return a [sqlalchemy.schema#SchemaVisitor] instance that can execute defaults.
- connection
- a [sqlalchemy.engine#Connection] to use for statement execution
+ execution_context
+ a [sqlalchemy.engine#ExecutionContext] to use for statement execution
"""
except AttributeError:
raise exceptions.InvalidRequestError("This Connection is closed")
+ def _branch(self):
+ """return a new Connection which references this Connection's
+ engine and connection; but does not have close_with_result enabled."""
+
+ return Connection(self.__engine, self.__connection)
+
engine = property(lambda s:s.__engine, doc="The Engine with which this Connection is associated.")
dialect = property(lambda s:s.__engine.dialect, doc="Dialect used by this Connection.")
connection = property(_get_connection, doc="The underlying DBAPI connection managed by this Connection.")
raise exceptions.InvalidRequestError("Unexecuteable object type: " + str(type(object)))
def _execute_default(self, default, multiparams=None, params=None):
- return self.__engine.dialect.defaultrunner(self).traverse_single(default)
+ return self.__engine.dialect.defaultrunner(self.__create_execution_context()).traverse_single(default)
def _execute_text(self, statement, multiparams, params):
parameters = self.__distill_params(multiparams, params)
DefaultRunner to allow database-specific behavior.
"""
- def __init__(self, connection):
- self.connection = connection
- self.dialect = connection.dialect
+ def __init__(self, context):
+ self.context = context
+ # branch the connection so it doesnt close after result
+ self.connection = context.connection._branch()
+ dialect = property(lambda self:self.context.dialect)
+
def get_column_default(self, column):
if column.default is not None:
return self.traverse_single(column.default)
if isinstance(onupdate.arg, sql.ClauseElement):
return self.exec_default_sql(onupdate)
elif callable(onupdate.arg):
- return onupdate.arg()
+ return onupdate.arg(self.context)
else:
return onupdate.arg
if isinstance(default.arg, sql.ClauseElement):
return self.exec_default_sql(default)
elif callable(default.arg):
- return default.arg()
+ return default.arg(self.context)
else:
return default.arg
def do_execute(self, cursor, statement, parameters, **kwargs):
cursor.execute(statement, parameters)
- def defaultrunner(self, connection):
- return base.DefaultRunner(connection)
+ def defaultrunner(self, context):
+ return base.DefaultRunner(context)
def is_disconnect(self, e):
return False
self.compiled_parameters = [compiled.construct_params(m or {}) for m in parameters]
if len(self.compiled_parameters) == 1:
self.compiled_parameters = self.compiled_parameters[0]
- else:
+ elif statement is not None:
self.typemap = self.column_labels = None
self.parameters = self.__encode_param_keys(parameters)
self.statement = statement
-
- if not dialect.supports_unicode_statements():
+ else:
+ self.statement = None
+
+ if self.statement is not None and not dialect.supports_unicode_statements():
self.statement = self.statement.encode(self.dialect.encoding)
self.cursor = self.create_cursor()
plist = self.compiled_parameters
else:
plist = [self.compiled_parameters]
- drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection))
+ drunner = self.dialect.defaultrunner(self)
self._lastrow_has_defaults = False
for param in plist:
last_inserted_ids = []
plist = self.compiled_parameters
else:
plist = [self.compiled_parameters]
- drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection))
+ drunner = self.dialect.defaultrunner(self)
self._lastrow_has_defaults = False
for param in plist:
# check the "onupdate" status of each column in the table
from sqlalchemy import sql, types, exceptions,util, databases
import sqlalchemy
-import re, string
+import re, string, inspect
__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', 'ForeignKeyConstraint',
'PrimaryKeyConstraint', 'CheckConstraint', 'UniqueConstraint', 'DefaultGenerator', 'Constraint',
def __init__(self, arg, **kwargs):
super(ColumnDefault, self).__init__(**kwargs)
- self.arg = arg
+ if callable(arg):
+ if not inspect.isfunction(arg):
+ self.arg = lambda ctx: arg()
+ else:
+ argspec = inspect.getargspec(arg)
+ if len(argspec[0]) == 0:
+ self.arg = lambda ctx: arg()
+ elif len(argspec[0]) != 1:
+ raise exceptions.ArgumentError("ColumnDefault Python function takes zero or one positional arguments")
+ else:
+ self.arg = arg
+ else:
+ self.arg = arg
def _visit_name(self):
if self.for_update:
import sqlalchemy.schema as schema
from sqlalchemy.orm import mapper, create_session
from testlib import *
+import datetime
class DefaultTest(PersistTest):
x['x'] += 1
return x['x']
+ def mydefault_with_ctx(ctx):
+ return ctx.compiled_parameters['col1'] + 10
+
+ def myupdate_with_ctx(ctx):
+ return len(ctx.compiled_parameters['col2'])
+
use_function_defaults = db.engine.name == 'postgres' or db.engine.name == 'oracle'
is_oracle = db.engine.name == 'oracle'
Column('col6', Date, default=currenttime, onupdate=currenttime),
Column('boolcol1', Boolean, default=True),
- Column('boolcol2', Boolean, default=False)
+ Column('boolcol2', Boolean, default=False),
+
+ # python function which uses ExecutionContext
+ Column('col7', Integer, default=mydefault_with_ctx, onupdate=myupdate_with_ctx),
+
+ # python builtin
+ Column('col8', Date, default=datetime.date.today, onupdate=datetime.date.today)
)
t.create()
def tearDown(self):
t.delete().execute()
-
+
+ def testargsignature(self):
+ def mydefault(x, y):
+ pass
+ try:
+ c = ColumnDefault(mydefault)
+ assert False
+ except exceptions.ArgumentError, e:
+ assert str(e) == "ColumnDefault Python function takes zero or one positional arguments", str(e)
+
def teststandalone(self):
c = testbase.db.engine.contextual_connect()
x = c.execute(t.c.col1.default)
ctexec = currenttime.scalar()
print "Currenttime "+ repr(ctexec)
l = t.select().execute()
- self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False), (52, 'imthedefault', f, ts, ts, ctexec, True, False), (53, 'imthedefault', f, ts, ts, ctexec, True, False)])
+ today = datetime.date.today()
+ self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False, 61, today), (52, 'imthedefault', f, ts, ts, ctexec, True, False, 62, today), (53, 'imthedefault', f, ts, ts, ctexec, True, False, 63, today)])
def testinsertvalues(self):
t.insert(values={'col3':50}).execute()
print "Currenttime "+ repr(ctexec)
l = t.select(t.c.col1==pk).execute()
l = l.fetchone()
- self.assert_(l == (pk, 'im the update', f2, None, None, ctexec, True, False))
+ self.assert_(l == (pk, 'im the update', f2, None, None, ctexec, True, False, 13, datetime.date.today()))
# mysql/other db's return 0 or 1 for count(1)
self.assert_(14 <= f2 <= 15)