]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
ColumnDefault functions pass ExecutionContext to callables which accept a single...
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 23 Jul 2007 20:56:27 +0000 (20:56 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 23 Jul 2007 20:56:27 +0000 (20:56 +0000)
refactored workings of defaults so that they share the same execution context.

CHANGES
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/schema.py
test/sql/defaults.py

diff --git a/CHANGES b/CHANGES
index f68b80c53226b1de5793df48246996dbb3d86dcb..4834e366269b62c1e02101767ea47675c617aa8c 100644 (file)
--- a/CHANGES
+++ b/CHANGES
     will also autoclose the connection if defined for the operation; this 
     allows more efficient usage of connections for successive CRUD operations
     with less chance of "dangling connections".
+  - Column defaults and onupdate Python functions (i.e. passed to ColumnDefault)
+    may take zero or one arguments; the one argument is the ExecutionContext,
+    from which you can call "context.parameters[someparam]" to access the other
+    bind parameter values affixed to the statement [ticket:559]
   - added "explcit" create/drop/execute support for sequences 
     (i.e. you can pass a "connectable" to each of those methods
     on Sequence)
index 96ca048b11b1796478b00b07a1933fa2aed05375..d8f467358ffa687ec7f588696c921bc1267d8c83 100644 (file)
@@ -265,8 +265,8 @@ class PGDialect(ansisql.ANSIDialect):
         resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts"))
         return [row[0] for row in resultset]
 
-    def defaultrunner(self, connection, **kwargs):
-        return PGDefaultRunner(connection, **kwargs)
+    def defaultrunner(self, context, **kwargs):
+        return PGDefaultRunner(context, **kwargs)
 
     def preparer(self):
         return PGIdentifierPreparer(self)
index 796df1a5c1ab89dac97469315f12957c4b3ef58e..d2a0d85d7da351ce3fe14eb974ec4788a7e75e96 100644 (file)
@@ -129,11 +129,11 @@ class Dialect(sql.AbstractDialect):
 
         raise NotImplementedError()
 
-    def defaultrunner(self, connection, **kwargs):
+    def defaultrunner(self, execution_context):
         """Return a [sqlalchemy.schema#SchemaVisitor] instance that can execute defaults.
         
-            connection
-                a [sqlalchemy.engine#Connection] to use for statement execution
+            execution_context
+                a [sqlalchemy.engine#ExecutionContext] to use for statement execution
         
         """
 
@@ -514,6 +514,12 @@ class Connection(Connectable):
         except AttributeError:
             raise exceptions.InvalidRequestError("This Connection is closed")
 
+    def _branch(self):
+        """return a new Connection which references this Connection's 
+        engine and connection; but does not have close_with_result enabled."""
+        
+        return Connection(self.__engine, self.__connection)
+        
     engine = property(lambda s:s.__engine, doc="The Engine with which this Connection is associated.")
     dialect = property(lambda s:s.__engine.dialect, doc="Dialect used by this Connection.")
     connection = property(_get_connection, doc="The underlying DBAPI connection managed by this Connection.")
@@ -694,7 +700,7 @@ class Connection(Connectable):
             raise exceptions.InvalidRequestError("Unexecuteable object type: " + str(type(object)))
 
     def _execute_default(self, default, multiparams=None, params=None):
-        return self.__engine.dialect.defaultrunner(self).traverse_single(default)
+        return self.__engine.dialect.defaultrunner(self.__create_execution_context()).traverse_single(default)
 
     def _execute_text(self, statement, multiparams, params):
         parameters = self.__distill_params(multiparams, params)
@@ -1461,10 +1467,13 @@ class DefaultRunner(schema.SchemaVisitor):
     DefaultRunner to allow database-specific behavior.
     """
 
-    def __init__(self, connection):
-        self.connection = connection
-        self.dialect = connection.dialect
+    def __init__(self, context):
+        self.context = context
+        # branch the connection so it doesnt close after result
+        self.connection = context.connection._branch()
         
+    dialect = property(lambda self:self.context.dialect)
+    
     def get_column_default(self, column):
         if column.default is not None:
             return self.traverse_single(column.default)
@@ -1502,7 +1511,7 @@ class DefaultRunner(schema.SchemaVisitor):
         if isinstance(onupdate.arg, sql.ClauseElement):
             return self.exec_default_sql(onupdate)
         elif callable(onupdate.arg):
-            return onupdate.arg()
+            return onupdate.arg(self.context)
         else:
             return onupdate.arg
 
@@ -1510,6 +1519,6 @@ class DefaultRunner(schema.SchemaVisitor):
         if isinstance(default.arg, sql.ClauseElement):
             return self.exec_default_sql(default)
         elif callable(default.arg):
-            return default.arg()
+            return default.arg(self.context)
         else:
             return default.arg
index dfdc1baaa495f619595025dba1809a4832a8baa6..b529b4672214d2c209b834b22091dafa83cdd7e8 100644 (file)
@@ -115,8 +115,8 @@ class DefaultDialect(base.Dialect):
     def do_execute(self, cursor, statement, parameters, **kwargs):
         cursor.execute(statement, parameters)
 
-    def defaultrunner(self, connection):
-        return base.DefaultRunner(connection)
+    def defaultrunner(self, context):
+        return base.DefaultRunner(context)
 
     def is_disconnect(self, e):
         return False
@@ -172,12 +172,14 @@ class DefaultExecutionContext(base.ExecutionContext):
                 self.compiled_parameters = [compiled.construct_params(m or {}) for m in parameters]
                 if len(self.compiled_parameters) == 1:
                     self.compiled_parameters = self.compiled_parameters[0]
-        else:
+        elif statement is not None:
             self.typemap = self.column_labels = None
             self.parameters = self.__encode_param_keys(parameters)
             self.statement = statement
-
-        if not dialect.supports_unicode_statements():
+        else:
+            self.statement = None
+            
+        if self.statement is not None and not dialect.supports_unicode_statements():
             self.statement = self.statement.encode(self.dialect.encoding)
             
         self.cursor = self.create_cursor()
@@ -306,7 +308,7 @@ class DefaultExecutionContext(base.ExecutionContext):
                 plist = self.compiled_parameters
             else:
                 plist = [self.compiled_parameters]
-            drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection))
+            drunner = self.dialect.defaultrunner(self)
             self._lastrow_has_defaults = False
             for param in plist:
                 last_inserted_ids = []
@@ -346,7 +348,7 @@ class DefaultExecutionContext(base.ExecutionContext):
                 plist = self.compiled_parameters
             else:
                 plist = [self.compiled_parameters]
-            drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection))
+            drunner = self.dialect.defaultrunner(self)
             self._lastrow_has_defaults = False
             for param in plist:
                 # check the "onupdate" status of each column in the table
index 7a278053747fb3c933b96b2efeee2de65d6244d0..00b9cff68c09d9cb48d1958d65781a841c891776 100644 (file)
@@ -19,7 +19,7 @@ objects as well as the visitor interface, so that the schema package
 
 from sqlalchemy import sql, types, exceptions,util, databases
 import sqlalchemy
-import re, string
+import re, string, inspect
 
 __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', 'ForeignKeyConstraint',
             'PrimaryKeyConstraint', 'CheckConstraint', 'UniqueConstraint', 'DefaultGenerator', 'Constraint',
@@ -802,7 +802,19 @@ class ColumnDefault(DefaultGenerator):
 
     def __init__(self, arg, **kwargs):
         super(ColumnDefault, self).__init__(**kwargs)
-        self.arg = arg
+        if callable(arg):
+            if not inspect.isfunction(arg):
+                self.arg = lambda ctx: arg()
+            else:
+                argspec = inspect.getargspec(arg)
+                if len(argspec[0]) == 0:
+                    self.arg = lambda ctx: arg()
+                elif len(argspec[0]) != 1:
+                    raise exceptions.ArgumentError("ColumnDefault Python function takes zero or one positional arguments")
+                else:
+                    self.arg = arg
+        else:
+            self.arg = arg
 
     def _visit_name(self):
         if self.for_update:
index 5cbdc3e3fb3151fdc1579f53817968b7b4bfe126..6c200232f229708cbf1cfb178e8e6007b52126c8 100644 (file)
@@ -4,6 +4,7 @@ import sqlalchemy.util as util
 import sqlalchemy.schema as schema
 from sqlalchemy.orm import mapper, create_session
 from testlib import *
+import datetime
 
 class DefaultTest(PersistTest):
 
@@ -17,6 +18,12 @@ class DefaultTest(PersistTest):
             x['x'] += 1
             return x['x']
 
+        def mydefault_with_ctx(ctx):
+            return ctx.compiled_parameters['col1'] + 10
+
+        def myupdate_with_ctx(ctx):
+            return len(ctx.compiled_parameters['col2'])
+            
         use_function_defaults = db.engine.name == 'postgres' or db.engine.name == 'oracle'
         is_oracle = db.engine.name == 'oracle'
  
@@ -66,7 +73,13 @@ class DefaultTest(PersistTest):
             Column('col6', Date, default=currenttime, onupdate=currenttime),
             
             Column('boolcol1', Boolean, default=True),
-            Column('boolcol2', Boolean, default=False)
+            Column('boolcol2', Boolean, default=False),
+            
+            # python function which uses ExecutionContext
+            Column('col7', Integer, default=mydefault_with_ctx, onupdate=myupdate_with_ctx),
+            
+            # python builtin
+            Column('col8', Date, default=datetime.date.today, onupdate=datetime.date.today)
         )
         t.create()
 
@@ -75,7 +88,16 @@ class DefaultTest(PersistTest):
     
     def tearDown(self):
         t.delete().execute()
-        
+    
+    def testargsignature(self):
+        def mydefault(x, y):
+            pass
+        try:
+            c = ColumnDefault(mydefault)
+            assert False
+        except exceptions.ArgumentError, e:
+            assert str(e) == "ColumnDefault Python function takes zero or one positional arguments", str(e)
+            
     def teststandalone(self):
         c = testbase.db.engine.contextual_connect()
         x = c.execute(t.c.col1.default)
@@ -96,7 +118,8 @@ class DefaultTest(PersistTest):
         ctexec = currenttime.scalar()
         print "Currenttime "+ repr(ctexec)
         l = t.select().execute()
-        self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False), (52, 'imthedefault', f, ts, ts, ctexec, True, False), (53, 'imthedefault', f, ts, ts, ctexec, True, False)])
+        today = datetime.date.today()
+        self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False, 61, today), (52, 'imthedefault', f, ts, ts, ctexec, True, False, 62, today), (53, 'imthedefault', f, ts, ts, ctexec, True, False, 63, today)])
 
     def testinsertvalues(self):
         t.insert(values={'col3':50}).execute()
@@ -112,7 +135,7 @@ class DefaultTest(PersistTest):
         print "Currenttime "+ repr(ctexec)
         l = t.select(t.c.col1==pk).execute()
         l = l.fetchone()
-        self.assert_(l == (pk, 'im the update', f2, None, None, ctexec, True, False))
+        self.assert_(l == (pk, 'im the update', f2, None, None, ctexec, True, False, 13, datetime.date.today()))
         # mysql/other db's return 0 or 1 for count(1)
         self.assert_(14 <= f2 <= 15)