]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
refactoring of execution path, defaults, and treatment of different paramstyles
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 17 Dec 2005 02:49:47 +0000 (02:49 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 17 Dec 2005 02:49:47 +0000 (02:49 +0000)
12 files changed:
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
test/query.py
test/select.py
test/tables.py
test/testbase.py

index cd1d3a0b0a6e74b09da40562a5b2c49271ee4745..e4bcdd077564cabd4c3bf37b886dd9adb1ef6a1e 100644 (file)
@@ -37,8 +37,8 @@ class ANSISQLEngine(sqlalchemy.engine.SQLEngine):
     def schemadropper(self, proxy, **params):
         return ANSISchemaDropper(proxy, **params)
 
-    def compiler(self, statement, bindparams, **kwargs):
-        return ANSICompiler(self, statement, bindparams, **kwargs)
+    def compiler(self, statement, parameters, **kwargs):
+        return ANSICompiler(self, statement, parameters, **kwargs)
     
     def connect_args(self):
         return ([],{})
@@ -47,8 +47,20 @@ class ANSISQLEngine(sqlalchemy.engine.SQLEngine):
         return None
 
 class ANSICompiler(sql.Compiled):
-    def __init__(self, engine, statement, bindparams, typemap=None, paramstyle=None,**kwargs):
-        sql.Compiled.__init__(self, engine, statement, bindparams)
+    """default implementation of Compiled, which compiles ClauseElements into ANSI-compliant SQL strings."""
+    def __init__(self, engine, statement, parameters=None, typemap=None, **kwargs):
+        """constructs a new ANSICompiler object.
+        
+        engine - SQLEngine to compile against
+        
+        statement - ClauseElement to be compiled
+        
+        parameters - optional dictionary indicating a set of bind parameters
+        specified with this Compiled object.  These parameters are the "default"
+        key/value pairs when the Compiled is executed, and also may affect the 
+        actual compilation, as in the case of an INSERT where the actual columns
+        inserted will correspond to the keys present in the parameters."""
+        sql.Compiled.__init__(self, engine, statement, parameters)
         self.binds = {}
         self.froms = {}
         self.wheres = {}
@@ -57,37 +69,18 @@ class ANSICompiler(sql.Compiled):
         self.typemap = typemap or {}
         self.isinsert = False
         
-        if paramstyle is None:
-            db = self.engine.dbapi()
-            if db is not None:
-                paramstyle = db.paramstyle
-            else:
-                paramstyle = 'named'
-
-        if paramstyle == 'named':
-            self.bindtemplate = ':%s'
-            self.positional=False
-        elif paramstyle =='pyformat':
-            self.bindtemplate = "%%(%s)s"
-            self.positional=False
-        else:
-            # for positional, use pyformat until the end
-            self.bindtemplate = "%%(%s)s"
-            self.positional=True
-        self.paramstyle=paramstyle
-        
     def after_compile(self):
-        if self.positional:
+        if self.engine.positional:
             self.positiontup = []
             match = r'%\(([\w_]+)\)s'
             params = re.finditer(match, self.strings[self.statement])
             for p in params:
                 self.positiontup.append(p.group(1))
-            if self.paramstyle=='qmark':
+            if self.engine.paramstyle=='qmark':
                 self.strings[self.statement] = re.sub(match, '?', self.strings[self.statement])
-            elif self.paramstyle=='format':
+            elif self.engine.paramstyle=='format':
                 self.strings[self.statement] = re.sub(match, '%s', self.strings[self.statement])
-            elif self.paramstyle=='numeric':
+            elif self.engine.paramstyle=='numeric':
                 i = 0
                 def getnum(x):
                     i += 1
@@ -116,14 +109,22 @@ class ANSICompiler(sql.Compiled):
         for an executemany style of call, this method should be called for each element
         in the list of parameter groups that will ultimately be executed.
         """
-        d = {}
-        if self.bindparams is not None:
-            bindparams = self.bindparams.copy()
+        if self.parameters is not None:
+            bindparams = self.parameters.copy()
         else:
             bindparams = {}
         bindparams.update(params)
-        # TODO: cant we make "d" an ordereddict and add params in 
-        # positional order
+
+        if self.engine.positional:
+            d = OrderedDict()
+            for k in self.positiontup:
+                b = self.binds[k]
+                d[k] = b.typeprocess(b.value)
+        else:
+            d = {}
+            for b in self.binds.values():
+                d[b.key] = b.typeprocess(b.value)
+            
         for key, value in bindparams.iteritems():
             try:
                 b = self.binds[key]
@@ -131,11 +132,9 @@ class ANSICompiler(sql.Compiled):
                 continue
             d[b.key] = b.typeprocess(value)
 
-        for b in self.binds.values():
-            d.setdefault(b.key, b.typeprocess(b.value))
-
-        if self.positional:
-            return [d[key] for key in self.positiontup]
+        return d
+        if self.engine.positional:
+            return d.values()
         else:
             return d
 
@@ -145,7 +144,8 @@ class ANSICompiler(sql.Compiled):
         same dictionary.  For a positional paramstyle, the given parameters are
         assumed to be in list format and are converted back to a dictionary.
         """
-        if self.positional:
+#        return parameters
+        if self.engine.positional:
             p = {}
             for i in range(0, len(self.positiontup)):
                 p[self.positiontup[i]] = parameters[i]
@@ -237,7 +237,7 @@ class ANSICompiler(sql.Compiled):
         self.strings[bindparam] = self.bindparam_string(key)
 
     def bindparam_string(self, name):
-        return self.bindtemplate % name
+        return self.engine.bindtemplate % name
         
     def visit_alias(self, alias):
         self.froms[alias] = self.get_from_text(alias.selectable) + " AS " + alias.name
@@ -265,7 +265,7 @@ class ANSICompiler(sql.Compiled):
         text = "SELECT "
         if select.distinct:
             text += "DISTINCT "
-        text += collist + " \nFROM "
+        text += collist
         
         whereclause = select.whereclause
         
@@ -282,8 +282,10 @@ class ANSICompiler(sql.Compiled):
             t = self.get_from_text(f)
             if t is not None:
                 froms.append(t)
-
-        text += string.join(froms, ', ')
+        
+        if len(froms):
+            text += " \nFROM "
+            text += string.join(froms, ', ')
 
         if whereclause is not None:
             t = self.get_str(whereclause)
@@ -333,10 +335,31 @@ class ANSICompiler(sql.Compiled):
             self.froms[join] = (self.get_from_text(join.left) + " JOIN " + righttext +
             " ON " + self.get_str(join.onclause))
         self.strings[join] = self.froms[join]
+
+    def visit_insert_column_default(self, column, default):
+        """called when visiting an Insert statement, for each column in the table that
+        contains a ColumnDefault object."""
+        self.parameters.setdefault(column.key, None)
+        
+    def visit_insert_sequence(self, column, sequence):
+        """called when visiting an Insert statement, for each column in the table that
+        contains a Sequence object."""
+        pass
         
     def visit_insert(self, insert_stmt):
+        # set up a call for the defaults and sequences inside the table
+        class DefaultVisitor(schema.SchemaVisitor):
+            def visit_column_default(s, cd):
+                self.visit_insert_column_default(c, cd)
+            def visit_sequence(s, seq):
+                self.visit_insert_sequence(c, seq)
+        vis = DefaultVisitor()
+        for c in insert_stmt.table.c:
+            if self.parameters.get(c.key, None) is None and c.default is not None:
+                c.default.accept_visitor(vis)
+        
         self.isinsert = True
-        colparams = insert_stmt.get_colparams(self.bindparams)
+        colparams = insert_stmt.get_colparams(self.parameters)
         for c in colparams:
             b = c[1]
             self.binds[b.key] = b
@@ -348,7 +371,7 @@ class ANSICompiler(sql.Compiled):
         self.strings[insert_stmt] = text
 
     def visit_update(self, update_stmt):
-        colparams = update_stmt.get_colparams(self.bindparams)
+        colparams = update_stmt.get_colparams(self.parameters)
         def create_param(p):
             if isinstance(p, sql.BindParamClause):
                 self.binds[p.key] = p
index 191a57ba60b3c961ff703dba62027e1700890a65..96bacf2101557d348ed63988c335f546b49daf83 100644 (file)
@@ -140,8 +140,7 @@ class MySQLEngine(ansisql.ANSISQLEngine):
     def last_inserted_ids(self):
         return self.context.last_inserted_ids
             
-    def post_exec(self, proxy, statement, parameters, compiled = None, **kwargs):
-        if compiled is None: return
+    def post_exec(self, proxy, compiled, parameters, **kwargs):
         if getattr(compiled, "isinsert", False):
             self.context.last_inserted_ids = [proxy().lastrowid]
     
index bc30a4937dff704ff64b65c18bb314d15aeb4339..163d387bc47760809e512663a25296198a131d6c 100644 (file)
@@ -118,8 +118,7 @@ class OracleSQLEngine(ansisql.ANSISQLEngine):
     def last_inserted_ids(self):
         return self.context.last_inserted_ids
 
-    def pre_exec(self, proxy, statement, parameters, compiled=None, **kwargs):
-        if compiled is None: return
+    def pre_exec(self, proxy, compiled, parameters, **kwargs):
         # this is just an assertion that all the primary key columns in an insert statement
         # have a value set up, or have a default generator ready to go
         if getattr(compiled, "isinsert", False):
index 29590da0acfc1754fd8451a9ebe3d73dcba7eb8d..0ec84dec4d8b0f52c8b5fa8f489f1beb3ce05756 100644 (file)
@@ -153,7 +153,7 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
         return PGSchemaDropper(proxy, **params)
 
     def defaultrunner(self, proxy):
-        return PGDefaultRunner(proxy)
+        return PGDefaultRunner(self, proxy)
         
     def get_default_schema_name(self):
         if not hasattr(self, '_default_schema_name'):
@@ -166,8 +166,7 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
     def pre_exec(self, proxy, statement, parameters, **kwargs):
         return
 
-    def post_exec(self, proxy, statement, parameters, compiled = None, **kwargs):
-        if compiled is None: return
+    def post_exec(self, proxy, compiled, parameters, **kwargs):
         if getattr(compiled, "isinsert", False) and self.context.last_inserted_ids is None:
             table = compiled.statement.table
             cursor = proxy()
@@ -200,15 +199,10 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
         ischema.reflecttable(self, table, ischema_names)
 
 class PGCompiler(ansisql.ANSICompiler):
-    def visit_insert(self, insert):
-        """inserts are required to have the primary keys be explicitly present.
-         mapper will by default not put them in the insert statement to comply
-         with autoincrement fields that require they not be present.  so, 
-         put them all in for columns where sequence usage is defined."""
-        for c in insert.table.primary_key:
-            if self.bindparams.get(c.key, None) is None and c.default is not None and not c.default.optional:
-                self.bindparams[c.key] = None
-        return ansisql.ANSICompiler.visit_insert(self, insert)
+
+    def visit_insert_sequence(self, column, sequence):
+        if self.parameters.get(column.key, None) is None and not sequence.optional:
+            self.parameters[column.key] = None
 
     def limit_clause(self, select):
         text = ""
@@ -223,7 +217,7 @@ class PGCompiler(ansisql.ANSICompiler):
 class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, override_pk=False, **kwargs):
         colspec = column.name
-        if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or column.default.optional):
+        if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
             colspec += " SERIAL"
         else:
             colspec += " " + column.type.get_col_spec()
index a70f65d2a201ed1da70d50a1478c9ac9587445d8..e743d14e0bc83dd00a85de8deffc54969e249d51 100644 (file)
@@ -108,8 +108,7 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine):
         params['poolclass'] = sqlalchemy.pool.SingletonThreadPool
         ansisql.ANSISQLEngine.__init__(self, **params)
 
-    def post_exec(self, proxy, statement, parameters, compiled = None, **kwargs):
-        if compiled is None: return
+    def post_exec(self, proxy, compiled, parameters, **kwargs):
         if getattr(compiled, "isinsert", False):
             self.context.last_inserted_ids = [proxy().lastrowid]
 
index 22ccbd11c010029526435cc30d34190dfe512a25..81d72b17b2da65af7b462457b58ed4b8ab39f198 100644 (file)
@@ -78,17 +78,21 @@ class SchemaIterator(schema.SchemaVisitor):
             self.buffer.truncate(0)
 
 class DefaultRunner(schema.SchemaVisitor):
-    def __init__(self, proxy):
+    def __init__(self, engine, proxy):
         self.proxy = proxy
+        self.engine = engine
 
     def visit_sequence(self, seq):
         """sequences are not supported by default"""
         return None
 
+    def exec_default_sql(self, default):
+        c = sql.select([default.arg], engine=self.engine).compile()
+        return self.proxy(str(c), c.get_params()).fetchone()[0]
+        
     def visit_column_default(self, default):
-        if isinstance(default.arg, ClauseElement):
-            c = default.arg.compile()
-            return proxy.execute(str(c), c.get_params())
+        if isinstance(default.arg, sql.ClauseElement):
+            return self.exec_default_sql(default)
         elif callable(default.arg):
             return default.arg()
         else:
@@ -113,11 +117,29 @@ class SQLEngine(schema.SchemaEngine):
         self.context = util.ThreadLocal(raiseerror=False)
         self.tables = {}
         self.notes = {}
+        self._figure_paramstyle()
         if logger is None:
             self.logger = sys.stdout
         else:
             self.logger = logger
-
+    
+    def _figure_paramstyle(self):
+        db = self.dbapi()
+        if db is not None:
+            self.paramstyle = db.paramstyle
+        else:
+            self.paramstyle = 'named'
+
+        if self.paramstyle == 'named':
+            self.bindtemplate = ':%s'
+            self.positional=False
+        elif self.paramstyle =='pyformat':
+            self.bindtemplate = "%%(%s)s"
+            self.positional=False
+        else:
+            # for positional, use pyformat until the end
+            self.bindtemplate = "%%(%s)s"
+            self.positional=True
         
     def type_descriptor(self, typeobj):
         if type(typeobj) is type:
@@ -131,9 +153,9 @@ class SQLEngine(schema.SchemaEngine):
         raise NotImplementedError()
 
     def defaultrunner(self, proxy):
-        return DefaultRunner(proxy)
+        return DefaultRunner(self, proxy)
         
-    def compiler(self, statement, bindparams):
+    def compiler(self, statement, parameters):
         raise NotImplementedError()
 
     def rowid_column_name(self):
@@ -152,11 +174,11 @@ class SQLEngine(schema.SchemaEngine):
         """drops a table given a schema.Table object."""
         table.accept_visitor(self.schemadropper(self.proxy(), **params))
 
-    def compile(self, statement, bindparams, **kwargs):
+    def compile(self, statement, parameters, **kwargs):
         """given a sql.ClauseElement statement plus optional bind parameters, creates a new
         instance of this engine's SQLCompiler, compiles the ClauseElement, and returns the
         newly compiled object."""
-        compiler = self.compiler(statement, bindparams, **kwargs)
+        compiler = self.compiler(statement, parameters, **kwargs)
         statement.accept_visitor(compiler)
         compiler.after_compile()
         return compiler
@@ -263,26 +285,15 @@ class SQLEngine(schema.SchemaEngine):
                 self.context.transaction = None
                 self.context.tcount = None
 
-    def _process_defaults(self, proxy, statement, parameters, compiled=None, **kwargs):
+    def _process_defaults(self, proxy, compiled, parameters, **kwargs):
         if compiled is None: return
         if getattr(compiled, "isinsert", False):
-            # TODO: this sucks.  we have to get "parameters" to be a more standardized object
-            if isinstance(parameters, list) and (isinstance(parameters[0], list) or isinstance(parameters[0], dict)):
+            if isinstance(parameters, list):
                 plist = parameters
             else:
                 plist = [parameters]
-            # inserts are usually one at a time.  but if we got a list of parameters,
-            # it will calculate last_inserted_ids for just the last row in the list. 
-            # TODO: why not make last_inserted_ids a 2D array since we have to explicitly sequence
-            # it or post-select anyway   
             drunner = self.defaultrunner(proxy)
             for param in plist:
-                # the parameters might be positional, so convert them 
-                # to a dictionary
-                # TODO: this is stupid.  or, is this stupid ?  
-                # any way we can just have an OrderedDict so we have the
-                # dictionary + postional version each time ?
-                param = compiled.get_named_params(param)
                 last_inserted_ids = []
                 need_lastrowid=False
                 for c in compiled.statement.table.c:
@@ -306,18 +317,18 @@ class SQLEngine(schema.SchemaEngine):
                     self.context.last_inserted_ids = last_inserted_ids
 
 
-    def pre_exec(self, proxy, statement, parameters, **kwargs):
+    def pre_exec(self, proxy, compiled, parameters, **kwargs):
         pass
 
-    def post_exec(self, proxy, statement, parameters, **kwargs):
+    def post_exec(self, proxy, compiled, parameters, **kwargs):
         pass
 
-    def execute(self, statement, parameters, connection=None, cursor=None, echo = None, typemap = None, commit=False, **kwargs):
+    def execute_compiled(self, compiled, parameters, connection=None, cursor=None, echo=None, **kwargs):
         """executes the given string-based SQL statement with the given parameters.  
 
         The parameters can be a dictionary or a list, or a list of dictionaries or lists, depending
         on the paramstyle of the DBAPI.
-        
+
         If the current thread has specified a transaction begin() for this engine, the
         statement will be executed in the context of the current transactional connection.
         Otherwise, a commit() will be performed immediately after execution, since the local
@@ -352,6 +363,62 @@ class SQLEngine(schema.SchemaEngine):
         def proxy(statement=None, parameters=None):
             if statement is None:
                 return cursor
+            
+            executemany = parameters is not None and isinstance(parameters, list)
+
+            if self.positional:
+                if executemany:
+                    parameters = [p.values() for p in parameters]
+                else:
+                    parameters = parameters.values()
+            
+            self.execute(statement, parameters, connection=connection, cursor=cursor)        
+            return cursor
+
+        self.pre_exec(proxy, compiled, parameters, **kwargs)
+        self._process_defaults(proxy, compiled, parameters, **kwargs)
+        proxy(str(compiled), parameters)
+        self.post_exec(proxy, compiled, parameters, **kwargs)
+        return ResultProxy(cursor, self, typemap=compiled.typemap)
+
+    def execute(self, statement, parameters, connection=None, cursor=None, echo = None, typemap = None, commit=False, **kwargs):
+        """executes the given string-based SQL statement with the given parameters.  
+
+        The parameters can be a dictionary or a list, or a list of dictionaries or lists, depending
+        on the paramstyle of the DBAPI.
+        
+        If the current thread has specified a transaction begin() for this engine, the
+        statement will be executed in the context of the current transactional connection.
+        Otherwise, a commit() will be performed immediately after execution, since the local
+        pooled connection is returned to the pool after execution without a transaction set
+        up.
+
+        In all error cases, a rollback() is immediately performed on the connection before
+        propigating the exception outwards.
+
+        Other options include:
+
+        connection  -  a DBAPI connection to use for the execute.  If None, a connection is
+                       pulled from this engine's connection pool.
+
+        echo        -  enables echo for this execution, which causes all SQL and parameters
+                       to be dumped to the engine's logging output before execution.
+
+        typemap     -  a map of column names mapped to sqlalchemy.types.TypeEngine objects.
+                       These will be passed to the created ResultProxy to perform
+                       post-processing on result-set values.
+
+        commit      -  if True, will automatically commit the statement after completion. """
+        if parameters is None:
+            parameters = {}
+
+        if connection is None:
+            connection = self.connection()
+
+        if cursor is None:
+            cursor = connection.cursor()
+
+        try:
             if echo is True or self.echo is not False:
                 self.log(statement)
                 self.log(repr(parameters))
@@ -359,18 +426,10 @@ class SQLEngine(schema.SchemaEngine):
                 self._executemany(cursor, statement, parameters)
             else:
                 self._execute(cursor, statement, parameters)
-            return cursor
-            
-        try:
-            self.pre_exec(proxy, statement, parameters, **kwargs)
-            self._process_defaults(proxy, statement, parameters, **kwargs)
-            proxy(statement, parameters)
-            self.post_exec(proxy, statement, parameters, **kwargs)
-            if commit or self.context.transaction is None:
+            if self.context.transaction is None:
                 self.do_commit(connection)
         except:
             self.do_rollback(connection)
-            # TODO: wrap DB exceptions ?
             raise
         return ResultProxy(cursor, self, typemap = typemap)
 
index 606bcf508de598f2867d65dc2bafbbc266520a7b..13ed33e8286f5764befa4097e78b1158df0730c1 100644 (file)
@@ -187,6 +187,7 @@ class Column(SchemaItem):
         self._impl = self.table.engine.columnimpl(self)
 
         if self.default is not None:
+            self.default = ColumnDefault(self.default)
             self._init_items(self.default)
         self._init_items(*self.args)
         self.args = None
index 86412c2dbd44670ad5e1b2bfb5028938fedb9fac..99755ae6c98cc6fc8080cb3f9e016916fea78549 100644 (file)
@@ -234,17 +234,37 @@ class Compiled(ClauseVisitor):
     object be dependent on the actual values of those bind parameters, even though it may
     reference those values as defaults."""
 
-    def __init__(self, engine, statement, bindparams):
+    def __init__(self, engine, statement, parameters):
+        """constructs a new Compiled object.
+        
+        engine - SQLEngine to compile against
+        
+        statement - ClauseElement to be compiled
+        
+        parameters - optional dictionary indicating a set of bind parameters
+        specified with this Compiled object.  These parameters are the "default"
+        values corresponding to the ClauseElement's BindParamClauses when the Compiled 
+        is executed.   In the case of an INSERT or UPDATE statement, these parameters 
+        will also result in the creation of new BindParamClause objects for each key
+        and will also affect the generated column list in an INSERT statement and the SET 
+        clauses of an UPDATE statement.  The keys of the parameter dictionary can
+        either be the string names of columns or actual sqlalchemy.schema.Column objects."""
         self.engine = engine
-        self.bindparams = bindparams
+        self.parameters = parameters
         self.statement = statement
 
     def __str__(self):
         """returns the string text of the generated SQL statement."""
         raise NotImplementedError()
     def get_params(self, **params):
-        """returns the bind params for this compiled object, with values overridden by 
-        those given in the **params dictionary"""
+        """returns the bind params for this compiled object.
+        
+        Will start with the default parameters specified when this Compiled object
+        was first constructed, and will override those values with those sent via
+        **params, which are key/value pairs.  Each key should match one of the 
+        BindParamClause objects compiled into this object; either the "key" or 
+        "shortname" property of the BindParamClause.
+        """
         raise NotImplementedError()
 
     def execute(self, *multiparams, **params):
@@ -254,7 +274,7 @@ class Compiled(ClauseVisitor):
         else:
             params = self.get_params(**params)
 
-        return self.engine.execute(str(self), params, compiled=self, typemap=self.typemap)
+        return self.engine.execute_compiled(self, params)
 
     def scalar(self, *multiparams, **params):
         """executes this compiled object via the execute() method, then 
@@ -326,7 +346,7 @@ class ClauseElement(object):
             return [self]
     columns = property(lambda s: s._get_columns())
     
-    def compile(self, engine = None, bindparams = None, typemap=None):
+    def compile(self, engine = None, parameters = None, typemap=None):
         """compiles this SQL expression using its underlying SQLEngine to produce
         a Compiled object.  If no engine can be found, an ansisql engine is used.
         bindparams is a dictionary representing the default bind parameters to be used with 
@@ -337,7 +357,7 @@ class ClauseElement(object):
         if engine is None:
             raise "no SQLEngine could be located within this ClauseElement."
 
-        return engine.compile(self, bindparams = bindparams, typemap=typemap)
+        return engine.compile(self, parameters=parameters, typemap=typemap)
 
     def __str__(self):
         e = self.engine
@@ -355,7 +375,7 @@ class ClauseElement(object):
             bindparams = multiparams[0]
         else:
             bindparams = params
-        c = self.compile(e, bindparams = bindparams)
+        c = self.compile(e, parameters=bindparams)
         return c.execute(*multiparams, **params)
 
     def scalar(self, *multiparams, **params):
index 8fc9694f457225ccd42fbc618bad898dcfd2db5e..75088da57f8e93762db4f9a602a999179ee4b840 100644 (file)
@@ -57,6 +57,23 @@ class QueryTest(PersistTest):
            print repr(users_with_date.select().execute().fetchall())
            users_with_date.drop()
 
+    def testdefaults(self):
+        x = {'x':0}
+        def mydefault():
+            x['x'] += 1
+            return x['x']
+            
+        t = Table('default_test1', db, 
+            Column('col1', Integer, primary_key=True, default=mydefault),
+            Column('col2', String(20), default="imthedefault"),
+            Column('col3', String(20), default=func.count(1)),
+        )
+        t.create()
+        t.insert().execute()
+        t.insert().execute()
+        t.insert().execute()
+        t.drop()
+        
     def testdelete(self):
         c = db.connection()
 
index b144f8804a3bd249210e753cb3659c5c9c1171ec..8a0027beaa6913aa495df6d413b3d4613bf60081 100644 (file)
@@ -357,7 +357,7 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable
         # check that the bind params sent along with a compile() call
         # get preserved when the params are retreived later
         s = select([table], table.c.id == bindparam('test'))
-        c = s.compile(bindparams = {'test' : 7})
+        c = s.compile(parameters = {'test' : 7})
         self.assert_(c.get_params() == {'test' : 7})
 
     def testcorrelatedsubquery(self):
@@ -425,7 +425,7 @@ class CRUDTest(SQLTest):
         self.runtest(update(table, table.c.id == 12, values = {table.c.id : 9}), "UPDATE mytable SET myid=:myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'mytable_myid': 12, 'myid': 9, 'description': 'test'})
         s = table.update(table.c.id == 12, values = {table.c.name : 'lala'})
         print str(s)
-        c = s.compile(bindparams = {'mytable_id':9,'name':'h0h0'})
+        c = s.compile(parameters = {'mytable_id':9,'name':'h0h0'})
         print str(c)
         self.assert_(str(s) == str(c))
         
index d0d06924888758f769080b3b33c892ca552c608b..807ecf764814f1bc16dd87bb3000b0ebf12e5801 100644 (file)
@@ -12,7 +12,7 @@ db = testbase.db
 
 
 users = Table('users', db,
-    Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
+    Column('user_id', Integer, Sequence('user_id_seq', optional=False), primary_key = True),
     Column('user_name', String(40)),
 )
 
index 435e26b7628ffc581cf1507b1748427a8ad50060..df4c186c3c2fa5478d616f7779ff7dc739b13cf7 100644 (file)
@@ -81,8 +81,8 @@ class EngineAssert(object):
     """decorates a SQLEngine object to match the incoming queries against a set of assertions."""
     def __init__(self, engine):
         self.engine = engine
-        self.realexec = engine.execute
-        engine.execute = self.execute
+        self.realexec = engine.execute_compiled
+        engine.execute_compiled = self.execute_compiled
         self.echo = engine.echo
         self.logger = engine.logger
         self.set_assert_list(None, None)
@@ -93,9 +93,10 @@ class EngineAssert(object):
         self.assert_list = list
         if list is not None:
             self.assert_list.reverse()
-    def execute(self, statement, parameters, **kwargs):
+    def execute_compiled(self, compiled, parameters, **kwargs):
         self.engine.echo = self.echo
         self.engine.logger = self.logger
+        statement = str(compiled)
         
         if self.assert_list is not None:
             item = self.assert_list.pop()
@@ -104,14 +105,7 @@ class EngineAssert(object):
                 params = params()
 
             # deal with paramstyles of different engines
-            if isinstance(self.engine, sqlite.SQLiteSQLEngine):
-                paramstyle = 'named'
-            else:
-                db = self.engine.dbapi()
-                if db is not None:
-                    paramstyle = db.paramstyle
-                else:
-                    paramstyle = 'named'
+            paramstyle = self.engine.paramstyle
             if paramstyle == 'named':
                 pass
             elif paramstyle =='pyformat':
@@ -127,31 +121,10 @@ class EngineAssert(object):
                 elif paramstyle=='numeric':
                     repl = None
                 counter = 0
-                def append_arg(match):
-                    names.append(match.group(1))
-                    if repl is None:
-                        counter += 1
-                        return counter
-                    else:
-                        return repl
-                # substitute bind string in query, translate bind param
-                # dict to a list (or a list of dicts to a list of lists)
-                query = re.sub(r':([\w_]+)', append_arg, query)
-                if isinstance(params, list):
-                    args = []
-                    for p in params:
-                        l = []
-                        args.append(l)
-                        for n in names:
-                            l.append(p[n])
-                else:
-                    args = []
-                    for n in names:
-                        args.append(params[n])
-                params = args
+                query = re.sub(r':([\w_]+)', repl, query)
             
             self.unittest.assert_(statement == query and params == parameters, "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
-        return self.realexec(statement, parameters, **kwargs)
+        return self.realexec(compiled, parameters, **kwargs)
 
 
 class TTestSuite(unittest.TestSuite):