]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- removed "parameters" argument from clauseelement.compile(), replaced with
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Sep 2007 00:08:57 +0000 (00:08 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Sep 2007 00:08:57 +0000 (00:08 +0000)
  "column_keys".  the parameters sent to execute() only interact with the
  insert/update statement compilation process in terms of the column names
  present but not the values for those columns.
  produces more consistent execute/executemany behavior, simplifies things a
  bit internally.

CHANGES
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/sql/query.py
test/sql/select.py
test/testlib/testing.py

diff --git a/CHANGES b/CHANGES
index 5a956f8ecebfead2460e7fef60fb3ad563751da6..9596524d5a2ed18d110bf468d028b56d7737ffb9 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -13,6 +13,13 @@ CHANGES
   so mappers within inheritance relationships need to be constructed in
   inheritance order (which should be the normal case anyway).
 
+- removed "parameters" argument from clauseelement.compile(), replaced with
+  "column_keys".  the parameters sent to execute() only interact with the 
+  insert/update statement compilation process in terms of the column names 
+  present but not the values for those columns.
+  produces more consistent execute/executemany behavior, simplifies things a 
+  bit internally.
+  
 0.4.0beta5
 ----------
 
index 3e8949e14f4019929e9ad163c5678394f1575ed7..7b0defc2ab3e571db25d9169657b1932fc38c156 100644 (file)
@@ -847,8 +847,8 @@ dialect_mapping = {
 
 
 class MSSQLCompiler(compiler.DefaultCompiler):
-    def __init__(self, dialect, statement, parameters, **kwargs):
-        super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs)
+    def __init__(self, *args, **kwargs):
+        super(MSSQLCompiler, self).__init__(*args, **kwargs)
         self.tablealiases = {}
 
     def get_select_precolumns(self, select):
index bd6f5b97cf4c491381ebbde110348b8f68b5becc..32bf7b780fe5adb93edc88c23314eee3624eb3ba 100644 (file)
@@ -421,7 +421,7 @@ class Compiled(object):
     defaults.
     """
 
-    def __init__(self, dialect, statement, parameters, bind=None):
+    def __init__(self, dialect, statement, column_keys=None, bind=None):
         """Construct a new ``Compiled`` object.
 
         dialect
@@ -430,26 +430,16 @@ class Compiled(object):
         statement
           ``ClauseElement`` to be compiled.
 
-        parameters
-          Optional dictionary indicating a set of bind parameters
-          specified with this ``Compiled`` object.  These parameters
-          are the *default* values corresponding to the
-          ``ClauseElement``'s ``_BindParamClauses`` when the
-          ``Compiled`` is executed.  In the case of an ``INSERT`` or
-          ``UPDATE`` statement, these parameters will also result in
-          the creation of new ``_BindParamClause`` objects for each
-          key and will also affect the generated column list in an
-          ``INSERT`` statement and the ``SET`` clauses of an
-          ``UPDATE`` statement.  The keys of the parameter dictionary
-          can either be the string names of columns or
-          ``_ColumnClause`` objects.
+        column_keys
+          a list of column names to be compiled into an INSERT or UPDATE
+          statement.
 
         bind
           Optional Engine or Connection to compile this statement against.
         """
         self.dialect = dialect
         self.statement = statement
-        self.parameters = parameters
+        self.column_keys = column_keys
         self.bind = bind
         self.can_execute = statement.supports_execution()
     
@@ -778,8 +768,8 @@ class Connection(Connectable):
 
         return self.execute(object, *multiparams, **params).scalar()
 
-    def statement_compiler(self, statement, parameters, **kwargs):
-        return self.dialect.statement_compiler(self.dialect, statement, parameters, bind=self, **kwargs)
+    def statement_compiler(self, statement, **kwargs):
+        return self.dialect.statement_compiler(self.dialect, statement, bind=self, **kwargs)
 
     def execute(self, object, *multiparams, **params):
         """Executes and returns a ResultProxy."""
@@ -808,25 +798,43 @@ class Connection(Connectable):
             parameters = list(multiparams)
         return parameters
 
+    def __distill_params_and_keys(self, multiparams, params):
+        if multiparams is None or len(multiparams) == 0:
+            if params:
+                parameters = params
+                keys = params.keys()
+            else:
+                parameters = None
+                keys = []
+            executemany = False
+        elif len(multiparams) == 1 and isinstance(multiparams[0], (list, tuple, dict)):
+            parameters = multiparams[0]
+            if isinstance(parameters, dict):
+                keys = parameters.keys()
+            else:
+                keys = parameters[0].keys()
+            executemany = False
+        else:
+            parameters = list(multiparams)
+            keys = parameters[0].keys()
+            executemany = True
+        return (parameters, keys, executemany)
+
     def _execute_function(self, func, multiparams, params):
         return self._execute_clauseelement(func.select(), multiparams, params)
 
     def _execute_clauseelement(self, elem, multiparams=None, params=None):
-        if multiparams:
-            param = multiparams[0]
-            executemany = len(multiparams) > 1
-        else:
-            param = params
-            executemany = False
-        return self._execute_compiled(elem.compile(dialect=self.dialect, parameters=param, inline=executemany), multiparams, params)
+        (params, keys, executemany) = self.__distill_params_and_keys(multiparams, params)
+        return self._execute_compiled(elem.compile(dialect=self.dialect, column_keys=keys, inline=executemany), distilled_params=params)
 
-    def _execute_compiled(self, compiled, multiparams=None, params=None):
+    def _execute_compiled(self, compiled, multiparams=None, params=None, distilled_params=None):
         """Execute a sql.Compiled object."""
         if not compiled.can_execute:
             raise exceptions.ArgumentError("Not an executeable clause: %s" % (str(compiled)))
 
-        params = self.__distill_params(multiparams, params)
-        context = self.__create_execution_context(compiled=compiled, parameters=params)
+        if distilled_params is None:
+            distilled_params = self.__distill_params(multiparams, params)
+        context = self.__create_execution_context(compiled=compiled, parameters=distilled_params)
 
         context.pre_execution()
         self.__execute_raw(context)
@@ -1119,8 +1127,8 @@ class Engine(Connectable):
         connection = self.contextual_connect(close_with_result=True)
         return connection._execute_compiled(compiled, multiparams, params)
 
-    def statement_compiler(self, statement, parameters, **kwargs):
-        return self.dialect.statement_compiler(self.dialect, statement, parameters, bind=self, **kwargs)
+    def statement_compiler(self, statement, **kwargs):
+        return self.dialect.statement_compiler(self.dialect, statement, bind=self, **kwargs)
 
     def connect(self, **kwargs):
         """Return a newly allocated Connection object."""
index 1cfebdc276059a46e152373df54c7dd738062593..eb416803a46c5147093d1b02d8985c53c340c523 100644 (file)
@@ -90,7 +90,7 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
 
     operators = OPERATORS
     
-    def __init__(self, dialect, statement, parameters=None, inline=False, **kwargs):
+    def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
         """Construct a new ``DefaultCompiler`` object.
 
         dialect
@@ -99,16 +99,12 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
         statement
           ClauseElement to be compiled
 
-        parameters
-          optional dictionary indicating a set of bind parameters
-          specified with this Compiled object.  These parameters are
-          the *default* key/value pairs when the Compiled is executed,
-          and also may affect the actual compilation, as in the case
-          of an INSERT where the actual columns inserted will
-          correspond to the keys present in the parameters.
+        column_keys
+          a list of column names to be compiled into an INSERT or UPDATE
+          statement.
         """
         
-        super(DefaultCompiler, self).__init__(dialect, statement, parameters, **kwargs)
+        super(DefaultCompiler, self).__init__(dialect, statement, column_keys, **kwargs)
 
         # if we are insert/update.  set to true when we visit an INSERT or UPDATE
         self.isinsert = self.isupdate = False
@@ -217,12 +213,10 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
         to produce a ClauseParameters structure, representing the bind arguments
         for a single statement execution, or one element of an executemany execution.
         """
-        
+
         d = sql_util.ClauseParameters(self.dialect, self.positiontup)
 
-        pd = self.parameters or {}
-        if params is not None:
-            pd.update(params)
+        pd = params or {}
 
         bind_names = self.bind_names
         for key, bind in self.binds.iteritems():
@@ -658,15 +652,15 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
 
         # no parameters in the statement, no parameters in the
         # compiled params - return binds for all columns
-        if self.parameters is None and stmt.parameters is None:
+        if self.column_keys is None and stmt.parameters is None:
             return [(c, create_bind_param(c, None)) for c in stmt.table.columns]
 
         # if we have statement parameters - set defaults in the
         # compiled params
-        if self.parameters is None:
+        if self.column_keys is None:
             parameters = {}
         else:
-            parameters = dict([(getattr(k, 'key', k), v) for k, v in self.parameters.iteritems()])
+            parameters = dict([(getattr(key, 'key', key), None) for key in self.column_keys])
 
         if stmt.parameters is not None:
             for k, v in stmt.parameters.iteritems():
index ac56289e88aed283363a7f2445262eacbb08213e..f88c418ebf3b57987acee7849ff49f3c1a83fe60 100644 (file)
@@ -963,19 +963,24 @@ class ClauseElement(object):
 
     def execute(self, *multiparams, **params):
         """Compile and execute this ``ClauseElement``."""
-
-        if multiparams:
-            compile_params = multiparams[0]
+        
+        if len(multiparams) == 0:
+            keys = params.keys()
+        elif isinstance(multiparams[0], dict):
+            keys = multiparams[0].keys()
+        elif isinstance(multiparams[0], (list, tuple)):
+            keys = multiparams[0][0].keys()
         else:
-            compile_params = params
-        return self.compile(bind=self.bind, parameters=compile_params, inline=(len(multiparams) > 1)).execute(*multiparams, **params)
+            keys = None
+
+        return self.compile(bind=self.bind, column_keys=keys, inline=(len(multiparams) > 1)).execute(*multiparams, **params)
 
     def scalar(self, *multiparams, **params):
         """Compile and execute this ``ClauseElement``, returning the result's scalar representation."""
 
         return self.execute(*multiparams, **params).scalar()
 
-    def compile(self, bind=None, parameters=None, compiler=None, dialect=None, inline=False):
+    def compile(self, bind=None, column_keys=None, compiler=None, dialect=None, inline=False):
         """Compile this SQL expression.
 
         Uses the given ``Compiler``, or the given ``AbstractDialect``
@@ -999,21 +1004,18 @@ class ClauseElement(object):
         ``SET`` and ``VALUES`` clause of those statements.
         """
 
-        if isinstance(parameters, (list, tuple)):
-            parameters = parameters[0]
-
         if compiler is None:
             if dialect is not None:
-                compiler = dialect.statement_compiler(dialect, self, parameters, inline=inline)
+                compiler = dialect.statement_compiler(dialect, self, column_keys=column_keys, inline=inline)
             elif bind is not None:
-                compiler = bind.statement_compiler(self, parameters, inline=inline)
+                compiler = bind.statement_compiler(self, column_keys=column_keys, inline=inline)
             elif self.bind is not None:
-                compiler = self.bind.statement_compiler(self, parameters, inline=inline)
+                compiler = self.bind.statement_compiler(self, column_keys=column_keys, inline=inline)
 
         if compiler is None:
             from sqlalchemy.engine.default import DefaultDialect
             dialect = DefaultDialect()
-            compiler = dialect.statement_compiler(dialect, self, parameters=parameters, inline=inline)
+            compiler = dialect.statement_compiler(dialect, self, column_keys=column_keys, inline=inline)
         compiler.compile()
         return compiler
     
index 4e68fb980b9d22da48835cd270f6e1ccce2d6250..a519dd974bc76084ed163a3e2dc07e5c2e577778 100644 (file)
@@ -32,6 +32,14 @@ class QueryTest(PersistTest):
         users.insert().execute(user_id = 7, user_name = 'jack')
         assert users.count().scalar() == 1
     
+    def test_insert_heterogeneous_params(self):
+        users.insert().execute(
+            {'user_id':7, 'user_name':'jack'},
+            {'user_id':8, 'user_name':'ed'},
+            {'user_id':9}
+        )
+        assert users.select().execute().fetchall() == [(7, 'jack'), (8, 'ed'), (9, None)]
+        
     def testupdate(self):
 
         users.insert().execute(user_id = 7, user_name = 'jack')
@@ -353,9 +361,9 @@ class QueryTest(PersistTest):
         )
         meta.create_all()
         try:
-            t.insert().execute(value=func.length("one"))
+            t.insert(values=dict(value=func.length("one"))).execute()
             assert t.select().execute().fetchone()['value'] == 3
-            t.update().execute(value=func.length("asfda"))
+            t.update(values=dict(value=func.length("asfda"))).execute()
             assert t.select().execute().fetchone()['value'] == 5
 
             r = t.insert(values=dict(value=func.length("sfsaafsda"))).execute()
@@ -363,14 +371,14 @@ class QueryTest(PersistTest):
             assert t.select(t.c.id==id).execute().fetchone()['value'] == 9
             t.update(values={t.c.value:func.length("asdf")}).execute()
             assert t.select().execute().fetchone()['value'] == 4
-
+            print "--------------------------"
             t2.insert().execute()
-            t2.insert().execute(value=func.length("one"))
-            t2.insert().execute(value=func.length("asfda") + -19, stuff="hi")
+            t2.insert(values=dict(value=func.length("one"))).execute()
+            t2.insert(values=dict(value=func.length("asfda") + -19)).execute(stuff="hi")
 
             assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == [(7,None), (3,None), (-14,"hi")]
             
-            t2.update().execute(value=func.length("asdsafasd"), stuff="some stuff")
+            t2.update(values=dict(value=func.length("asdsafasd"))).execute(stuff="some stuff")
             assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == [(9,"some stuff"), (9,"some stuff"), (9,"some stuff")]
             
             t2.delete().execute()
index 5eaea7480ed58b0e31f96101434af081b0e151f0..edca33bc0b60ae04433399b76c252cc52d1ae698 100644 (file)
@@ -567,10 +567,9 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid =
 
     def testtextbinds(self):
         self.assert_compile(
-            text("select * from foo where lala=:bar and hoho=:whee"), 
+            text("select * from foo where lala=:bar and hoho=:whee", bindparams=[bindparam('bar', 4), bindparam('whee', 7)]), 
                 "select * from foo where lala=:bar and hoho=:whee", 
                 checkparams={'bar':4, 'whee': 7},
-                params={'bar':4, 'whee': 7, 'hoho':10},
         )
 
         self.assert_compile(
@@ -582,10 +581,9 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid =
 
         dialect = postgres.dialect()
         self.assert_compile(
-            text("select * from foo where lala=:bar and hoho=:whee"), 
+            text("select * from foo where lala=:bar and hoho=:whee", bindparams=[bindparam('bar',4), bindparam('whee',7)]), 
                 "select * from foo where lala=%(bar)s and hoho=%(whee)s", 
                 checkparams={'bar':4, 'whee': 7},
-                params={'bar':4, 'whee': 7, 'hoho':10},
                 dialect=dialect
         )
         self.assert_compile(
@@ -598,10 +596,9 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid =
 
         dialect = sqlite.dialect()
         self.assert_compile(
-            text("select * from foo where lala=:bar and hoho=:whee"), 
+            text("select * from foo where lala=:bar and hoho=:whee", bindparams=[bindparam('bar',4), bindparam('whee',7)]), 
                 "select * from foo where lala=? and hoho=?", 
                 checkparams=[4, 7],
-                params={'bar':4, 'whee': 7, 'hoho':10},
                 dialect=dialect
         )
         
@@ -936,11 +933,6 @@ EXISTS (select yay from foo where boo = lar)",
         except exceptions.CompileError, err:
             assert str(err) == "Bind parameter 'mytable_myid_1' conflicts with unique bind parameter of the same name"
             
-        # check that the bind params sent along with a compile() call
-        # get preserved when the params are retreived later
-        s = select([table1], table1.c.myid == bindparam('test'))
-        c = s.compile(parameters = {'test' : 7})
-        self.assert_(c.get_params().get_original_dict() == {'test' : 7})
 
     def testbindascol(self):
         t = table('foo', column('id'))
@@ -1134,7 +1126,7 @@ class CRUDTest(SQLCompileTest):
         self.assert_compile(table.insert(inline=True), "INSERT INTO sometable (foo) VALUES (foobar())", params={})    
             
     def testinsertexpression(self):
-        self.assert_compile(insert(table1), "INSERT INTO mytable (myid) VALUES (lala())", params=dict(myid=func.lala()))
+        self.assert_compile(insert(table1, values=dict(myid=func.lala())), "INSERT INTO mytable (myid) VALUES (lala())")
         
     def testupdate(self):
         self.assert_compile(update(table1, table1.c.myid == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid", params = {table1.c.name:'fred'})
@@ -1144,7 +1136,7 @@ class CRUDTest(SQLCompileTest):
         self.assert_compile(update(table1, table1.c.myid == 12, values = {table1.c.name : table1.c.myid}), "UPDATE mytable SET name=mytable.myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'description':'test'})
         self.assert_compile(update(table1, table1.c.myid == 12, values = {table1.c.myid : 9}), "UPDATE mytable SET myid=:myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'mytable_myid': 12, 'myid': 9, 'description': 'test'})
         s = table1.update(table1.c.myid == 12, values = {table1.c.name : 'lala'})
-        c = s.compile(parameters = {'mytable_id':9,'name':'h0h0'})
+        c = s.compile(column_keys=['mytable_id', 'name'])
         self.assert_compile(update(table1, table1.c.myid == 12, values = {table1.c.name : table1.c.myid}).values({table1.c.name:table1.c.name + 'foo'}), "UPDATE mytable SET name=(mytable.name || :mytable_name), description=:description WHERE mytable.myid = :mytable_myid", params = {'description':'test'})
         self.assert_(str(s) == str(c))
         
index 0038fddfe916f12b41b5858f70d846364ad10417..26873b25f611ae5d2f0096ce068b9c0879b7111f 100644 (file)
@@ -211,8 +211,13 @@ class SQLCompileTest(PersistTest):
     def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None):
         if dialect is None:
             dialect = getattr(self, '__dialect__', None)
-            
-        c = clause.compile(parameters=params, dialect=dialect)
+        
+        if params is None:
+            keys = None
+        else:
+            keys = params.keys()
+                
+        c = clause.compile(column_keys=keys, dialect=dialect)
 
         print "\nSQL String:\n" + str(c) + repr(c.get_params())