]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
move execute parameter processing from sql.ClauseElement to engine.execute_compiled
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 23 Dec 2005 01:37:10 +0000 (01:37 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 23 Dec 2005 01:37:10 +0000 (01:37 +0000)
testbase gets "assert_sql_count" method, moves execution wrapping to pre_exec to accomodate engine change
move _get_colparams from Insert/Update to ansisql since it applies to compilation
ansisql also insures that select list for columns is unique, helps the mapper with the "distinct" keyword
docstrings/cleanup

lib/sqlalchemy/ansisql.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/sql.py
test/testbase.py

index d8d2662ba5f18ee855a248bf5e33372d69f20304..7a90e746a093eaa78e0e0cfeee538882be082cba 100644 (file)
@@ -244,23 +244,32 @@ class ANSICompiler(sql.Compiled):
         self.strings[alias] = self.get_str(alias.selectable)
 
     def visit_select(self, select):
-        inner_columns = []
-
+        
+        # the actual list of columns to print in the SELECT column list.
+        # its an ordered dictionary to insure that the actual labeled column name
+        # is unique.
+        inner_columns = OrderedDict()
+        def col_key(c):
+            if select.use_labels:
+                return c.label
+            else:
+                return self.get_str(c)
+                
         self.select_stack.append(select)
         for c in select._raw_columns:
             if c.is_selectable():
                 for co in c.columns:
                     co.accept_visitor(self)
-                    inner_columns.append(co)
+                    inner_columns[col_key(co)] = co
             else:
                 c.accept_visitor(self)
-                inner_columns.append(c)
+                inner_columns[col_key(c)] = c
         self.select_stack.pop(-1)
         
         if select.use_labels:
-            collist = string.join(["%s AS %s" % (self.get_str(c), c.label) for c in inner_columns], ', ')
+            collist = string.join(["%s AS %s" % (self.get_str(v), k) for k, v in inner_columns.iteritems()], ', ')
         else:
-            collist = string.join([self.get_str(c) for c in inner_columns], ', ')
+            collist = string.join([k for k in inner_columns.keys()], ', ')
 
         text = "SELECT "
         if select.distinct:
@@ -275,7 +284,7 @@ class ANSICompiler(sql.Compiled):
         # matching those keys
         if self.parameters is not None:
             revisit = False
-            for c in inner_columns:
+            for c in inner_columns.values():
                 if self.parameters.has_key(c.key) and not self.binds.has_key(c.key):
                     value = self.parameters[c.key]
                 elif self.parameters.has_key(c.label) and not self.binds.has_key(c.label):
@@ -377,7 +386,7 @@ class ANSICompiler(sql.Compiled):
                 c.default.accept_visitor(vis)
         
         self.isinsert = True
-        colparams = insert_stmt.get_colparams(self.parameters)
+        colparams = self._get_colparams(insert_stmt)
         for c in colparams:
             b = c[1]
             self.binds[b.key] = b
@@ -389,7 +398,7 @@ class ANSICompiler(sql.Compiled):
         self.strings[insert_stmt] = text
 
     def visit_update(self, update_stmt):
-        colparams = update_stmt.get_colparams(self.parameters)
+        colparams = self._get_colparams(update_stmt)
         def create_param(p):
             if isinstance(p, sql.BindParamClause):
                 self.binds[p.key] = p
@@ -409,6 +418,59 @@ class ANSICompiler(sql.Compiled):
          
         self.strings[update_stmt] = text
 
+
+    def _get_colparams(self, stmt):
+        """determines the VALUES or SET clause for an INSERT or UPDATE
+        clause based on the arguments specified to this ANSICompiler object
+        (i.e., the execute() or compile() method clause object):
+
+        insert(mytable).execute(col1='foo', col2='bar')
+        mytable.update().execute(col2='foo', col3='bar')
+
+        in the above examples, the insert() and update() methods have no "values" sent to them
+        at all, so compiling them with no arguments would yield an insert for all table columns,
+        or an update with no SET clauses.  but the parameters sent indicate a set of per-compilation
+        arguments that result in a differently compiled INSERT or UPDATE object compared to the
+        original.  The "values" parameter to the insert/update is figured as well if present,
+        but the incoming "parameters" sent here take precedence.
+        """
+        # case one: no parameters in the statement, no parameters in the 
+        # compiled params - just return binds for all the table columns
+        if self.parameters is None and stmt.parameters is None:
+            return [(c, bindparam(c.name, type=c.type)) for c in stmt.table.columns]
+
+        # if we have statement parameters - set defaults in the 
+        # compiled params
+        if self.parameters is None:
+            parameters = {}
+        else:
+            parameters = self.parameters.copy()
+
+        if stmt.parameters is not None:
+            for k, v in stmt.parameters.iteritems():
+                parameters.setdefault(k, v)
+
+        # now go thru compiled params, get the Column object for each key
+        d = {}
+        for key, value in parameters.iteritems():
+            if isinstance(key, schema.Column):
+                d[key] = value
+            else:
+                try:
+                    d[stmt.table.columns[str(key)]] = value
+                except KeyError:
+                    pass
+
+        # create a list of column assignment clauses as tuples
+        values = []
+        for c in stmt.table.columns:
+            if d.has_key(c):
+                value = d[c]
+                if sql._is_literal(value):
+                    value = bindparam(c.name, value, type=c.type)
+                values.append((c, value))
+        return values
+
     def visit_delete(self, delete_stmt):
         text = "DELETE FROM " + delete_stmt.table.fullname
         
index 349bb4d1d6da3bc30d8c9913a1a034a4bf41c265..1e5ba34d020aca50650dffba0bc2391501cb2757 100644 (file)
@@ -324,10 +324,11 @@ class SQLEngine(schema.SchemaEngine):
         pass
 
     def execute_compiled(self, compiled, parameters, connection=None, cursor=None, echo=None, **kwargs):
-        """executes the given string-based SQL statement with the given parameters.  
+        """executes the given compiled statement object with the given parameters.  
 
-        The parameters can be a dictionary or a list, or a list of dictionaries or lists, depending
-        on the paramstyle of the DBAPI.
+        The parameters can be a dictionary of key/value pairs, or a list of dictionaries for an
+        executemany() style of execution.  Engines that use positional parameters will convert
+        the parameters to a list before execution.
 
         If the current thread has specified a transaction begin() for this engine, the
         statement will be executed in the context of the current transactional connection.
@@ -360,6 +361,12 @@ class SQLEngine(schema.SchemaEngine):
         if cursor is None:
             cursor = connection.cursor()
 
+        executemany = parameters is not None and (isinstance(parameters, list) or isinstance(parameters, tuple))
+        if executemany:
+            parameters = [compiled.get_params(**m) for m in parameters]
+        else:
+            parameters = compiled.get_params(**parameters)
+        
         def proxy(statement=None, parameters=None):
             if statement is None:
                 return cursor
@@ -371,7 +378,7 @@ class SQLEngine(schema.SchemaEngine):
                     parameters = [p.values() for p in parameters]
                 else:
                     parameters = parameters.values()
-            
+
             self.execute(statement, parameters, connection=connection, cursor=cursor)        
             return cursor
 
index a634767eaa0648d988dc668996cea8cd5e6e7cd1..3a248f43471e9315bf058019af2c11eba039b1ef 100644 (file)
@@ -15,7 +15,6 @@
 # along with this library; if not, write to the Free Software
 # Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
 
-
 """defines the base components of SQL expression trees."""
 
 import sqlalchemy.schema as schema
@@ -270,10 +269,8 @@ class Compiled(ClauseVisitor):
     def execute(self, *multiparams, **params):
         """executes this compiled object using the underlying SQLEngine"""
         if len(multiparams):
-            params = [self.get_params(**m) for m in multiparams]
-        else:
-            params = self.get_params(**params)
-
+            params = multiparams
+            
         return self.engine.execute_compiled(self, params)
 
     def scalar(self, *multiparams, **params):
@@ -447,56 +444,50 @@ class CompareMixin(object):
         return BinaryClause(self, obj, operator)
 
 class FromClause(ClauseElement):
-    """represents a FROM clause element in a SQL statement."""
-    
+    """represents an element within the FROM clause of a SELECT statement."""
     def __init__(self, from_name = None, from_key = None):
         self.from_name = from_name
         self.id = from_key or from_name
-        
     def _get_from_objects(self):
         # this could also be [self], at the moment it doesnt matter to the Select object
         return []
-        
     def hash_key(self):
         return "FromClause(%s, %s)" % (repr(self.id), repr(self.from_name))
-            
     def accept_visitor(self, visitor): 
         visitor.visit_fromclause(self)
     
 class BindParamClause(ClauseElement, CompareMixin):
+    """represents a bind parameter.  public constructor is the bindparam() function."""
     def __init__(self, key, value, shortname = None, type = None):
         self.key = key
         self.value = value
         self.shortname = shortname
         self.type = type or types.NULLTYPE
-
     def accept_visitor(self, visitor):
         visitor.visit_bindparam(self)
-
     def _get_from_objects(self):
         return []
-     
     def hash_key(self):
         return "BindParam(%s, %s, %s)" % (repr(self.key), repr(self.value), repr(self.shortname))
-
     def typeprocess(self, value):
         return self.type.convert_bind_param(value)
             
 class TextClause(ClauseElement):
-    """represents literal text, including SQL fragments as well
-    as literal (non bind-param) values."""
+    """represents literal a SQL text fragment.  public constructor is the 
+    text() function.  
     
-    def __init__(self, text = "", engine=None, isliteral=False):
+    TextClauses, since they can be anything, have no comparison operators or
+    typing information.
+      
+    A single literal value within a compiled SQL statement is more useful 
+    being specified as a bind parameter via the bindparam() method,
+    since it provides more information about what it is, including an optional
+    type, as well as providing comparison operations."""
+    def __init__(self, text = "", engine=None):
         self.text = text
         self.parens = False
         self._engine = engine
         self.id = id(self)
-        if isliteral:
-            if isinstance(text, int) or isinstance(text, long):
-                self.text = str(text)
-            else:
-                text = re.sub(r"'", r"''", text)
-                self.text = "'" + text + "'"
     def accept_visitor(self, visitor): 
         visitor.visit_textclause(self)
     def hash_key(self):
@@ -505,6 +496,8 @@ class TextClause(ClauseElement):
         return []
 
 class Null(ClauseElement):
+    """represents the NULL keyword in a SQL statement. public contstructor is the
+    null() function."""
     def accept_visitor(self, visitor):
         visitor.visit_null(self)
     def _get_from_objects(self):
@@ -856,8 +849,8 @@ class TableImpl(Selectable):
         self._rowid_column._set_parent(table)
     
     rowid_column = property(lambda s: s._rowid_column)
-    
     engine = property(lambda s: s.table.engine)
+    columns = property(lambda self: self.table.columns)
 
     def _get_col_by_original(self, column):
         try:
@@ -880,35 +873,24 @@ class TableImpl(Selectable):
     
     def join(self, right, *args, **kwargs):
         return Join(self.table, right, *args, **kwargs)
-    
     def outerjoin(self, right, *args, **kwargs):
         return Join(self.table, right, isouter = True, *args, **kwargs)
-
     def alias(self, name):
         return Alias(self.table, name)
-            
     def select(self, whereclause = None, **params):
         return select([self.table], whereclause, **params)
-
     def insert(self, values = None):
         return insert(self.table, values=values)
-
     def update(self, whereclause = None, values = None):
         return update(self.table, whereclause, values)
-
     def delete(self, whereclause = None):
         return delete(self.table, whereclause)
-        
-    columns = property(lambda self: self.table.columns)
-
-    def _get_from_objects(self):
-        return [self.table]
-
     def create(self, **params):
         self.table.engine.create(self.table)
-
     def drop(self, **params):
         self.table.engine.drop(self.table)
+    def _get_from_objects(self):
+        return [self.table]
 
 class SelectBaseMixin(object):
     """base class for Select and CompoundSelects"""
@@ -1091,6 +1073,10 @@ class Select(SelectBaseMixin, Selectable):
     froms = property(lambda s: s._get_froms())
 
     def accept_visitor(self, visitor):
+        # TODO: add contextual visit_ methods
+        # visit_select_whereclause, visit_select_froms, visit_select_orderby, etc.
+        # which will allow the compiler to set contextual flags before traversing 
+        # into each thing.  
         for f in self._get_froms():
             f.accept_visitor(visitor)
         if self.whereclause is not None:
@@ -1118,16 +1104,13 @@ class Select(SelectBaseMixin, Selectable):
                 self._engine = e
                 return e
         return None
-    
-
 
 class UpdateBase(ClauseElement):
-    """forms the base for INSERT, UPDATE, and DELETE statements.  
-    Deals with the special needs of INSERT and UPDATE parameter lists -  
-    these statements have two separate lists of parameters, those
-    defined when the statement is constructed, and those specified at compile time."""
+    """forms the base for INSERT, UPDATE, and DELETE statements."""
     
     def _process_colparams(self, parameters):
+        """receives the "values" of an INSERT or UPDATE statement and constructs
+        appropriate ind parameters."""
         if parameters is None:
             return None
 
@@ -1154,57 +1137,6 @@ class UpdateBase(ClauseElement):
                     del parameters[key]
         return parameters
         
-    def get_colparams(self, parameters):
-        """this is used by the ANSICompiler to determine the VALUES or SET clause based on the arguments 
-        specified to the execute() or compile() method of the INSERT or UPDATE clause:
-        
-        insert(mytable).execute(col1='foo', col2='bar')
-        mytable.update().execute(col2='foo', col3='bar')
-        
-        in the above examples, the insert() and update() methods have no "values" sent to them
-        at all, so compiling them with no arguments would yield an insert for all table columns,
-        or an update with no SET clauses.  but the parameters sent indicate a set of per-compilation
-        arguments that result in a differently compiled INSERT or UPDATE object compared to the
-        original.  The "values" parameter to the insert/update is figured as well if present,
-        but the incoming "parameters" sent here take precedence.
-        """
-        # case one: no parameters in the statement, no parameters in the 
-        # compiled params - just return binds for all the table columns
-        if parameters is None and self.parameters is None:
-            return [(c, bindparam(c.name, type=c.type)) for c in self.table.columns]
-
-        # if we have statement parameters - set defaults in the 
-        # compiled params
-        if parameters is None:
-            parameters = {}
-        else:
-            parameters = parameters.copy()
-            
-        if self.parameters is not None:
-            for k, v in self.parameters.iteritems():
-                parameters.setdefault(k, v)
-
-        # now go thru compiled params, get the Column object for each key
-        d = {}
-        for key, value in parameters.iteritems():
-            if isinstance(key, schema.Column):
-                d[key] = value
-            else:
-                try:
-                    d[self.table.columns[str(key)]] = value
-                except KeyError:
-                    pass
-
-        # create a list of column assignment clauses as tuples
-        values = []
-        for c in self.table.columns:
-            if d.has_key(c):
-                value = d[c]
-                if _is_literal(value):
-                    value = bindparam(c.name, value, type=c.type)
-                values.append((c, value))
-        return values
-
 
 class Insert(UpdateBase):
     def __init__(self, table, values=None, **params):
index ab32f803feeb7fc6a9791d923f119fa30285ed82..d08b585277f4adaa96093ecb22511685ed7803a7 100644 (file)
@@ -76,15 +76,22 @@ class AssertMixin(PersistTest):
             callable_()
         finally:
             db.set_assert_list(None, None)
+    def assert_sql_count(self, db, callable_, count):
+        db.sql_count = 0
+        try:
+            callable_()
+        finally:
+            self.assert_(db.sql_count == count, "desired statement count %d does not match %d" % (count, db.sql_count))
         
 class EngineAssert(object):
     """decorates a SQLEngine object to match the incoming queries against a set of assertions."""
     def __init__(self, engine):
         self.engine = engine
-        self.realexec = engine.execute_compiled
-        engine.execute_compiled = self.execute_compiled
+        self.realexec = engine.pre_exec
+        engine.pre_exec = self.pre_exec
         self.logger = engine.logger
         self.set_assert_list(None, None)
+        self.sql_count = 0
     def __getattr__(self, key):
         return getattr(self.engine, key)
     def set_assert_list(self, unittest, list):
@@ -92,15 +99,14 @@ class EngineAssert(object):
         self.assert_list = list
         if list is not None:
             self.assert_list.reverse()
-
     def _set_echo(self, echo):
         self.engine.echo = echo
     echo = property(lambda s: s.engine.echo, _set_echo)
-    def execute_compiled(self, compiled, parameters, **kwargs):
+    def pre_exec(self, proxy, compiled, parameters, **kwargs):
         self.engine.logger = self.logger
         statement = str(compiled)
         statement = re.sub(r'\n', '', statement)
-        
+
         if self.assert_list is not None:
             item = self.assert_list.pop()
             (query, params) = item
@@ -127,7 +133,8 @@ class EngineAssert(object):
                 query = re.sub(r':([\w_]+)', repl, query)
 
             self.unittest.assert_(statement == query and params == parameters, "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
-        return self.realexec(compiled, parameters, **kwargs)
+        self.sql_count += 1
+        return self.realexec(proxy, compiled, parameters, **kwargs)
 
 
 class TTestSuite(unittest.TestSuite):