From dbd407d62ac3cbf6e54de7499f1a95b54e3e4204 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 23 Dec 2005 01:37:10 +0000 Subject: [PATCH] move execute parameter processing from sql.ClauseElement to engine.execute_compiled testbase gets "assert_sql_count" method, moves execution wrapping to pre_exec to accomodate engine change move _get_colparams from Insert/Update to ansisql since it applies to compilation ansisql also insures that select list for columns is unique, helps the mapper with the "distinct" keyword docstrings/cleanup --- lib/sqlalchemy/ansisql.py | 80 ++++++++++++++++++++++--- lib/sqlalchemy/engine.py | 15 +++-- lib/sqlalchemy/sql.py | 120 +++++++++----------------------------- test/testbase.py | 19 ++++-- 4 files changed, 121 insertions(+), 113 deletions(-) diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index d8d2662ba5..7a90e746a0 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -244,23 +244,32 @@ class ANSICompiler(sql.Compiled): self.strings[alias] = self.get_str(alias.selectable) def visit_select(self, select): - inner_columns = [] - + + # the actual list of columns to print in the SELECT column list. + # its an ordered dictionary to insure that the actual labeled column name + # is unique. + inner_columns = OrderedDict() + def col_key(c): + if select.use_labels: + return c.label + else: + return self.get_str(c) + self.select_stack.append(select) for c in select._raw_columns: if c.is_selectable(): for co in c.columns: co.accept_visitor(self) - inner_columns.append(co) + inner_columns[col_key(co)] = co else: c.accept_visitor(self) - inner_columns.append(c) + inner_columns[col_key(c)] = c self.select_stack.pop(-1) if select.use_labels: - collist = string.join(["%s AS %s" % (self.get_str(c), c.label) for c in inner_columns], ', ') + collist = string.join(["%s AS %s" % (self.get_str(v), k) for k, v in inner_columns.iteritems()], ', ') else: - collist = string.join([self.get_str(c) for c in inner_columns], ', ') + collist = string.join([k for k in inner_columns.keys()], ', ') text = "SELECT " if select.distinct: @@ -275,7 +284,7 @@ class ANSICompiler(sql.Compiled): # matching those keys if self.parameters is not None: revisit = False - for c in inner_columns: + for c in inner_columns.values(): if self.parameters.has_key(c.key) and not self.binds.has_key(c.key): value = self.parameters[c.key] elif self.parameters.has_key(c.label) and not self.binds.has_key(c.label): @@ -377,7 +386,7 @@ class ANSICompiler(sql.Compiled): c.default.accept_visitor(vis) self.isinsert = True - colparams = insert_stmt.get_colparams(self.parameters) + colparams = self._get_colparams(insert_stmt) for c in colparams: b = c[1] self.binds[b.key] = b @@ -389,7 +398,7 @@ class ANSICompiler(sql.Compiled): self.strings[insert_stmt] = text def visit_update(self, update_stmt): - colparams = update_stmt.get_colparams(self.parameters) + colparams = self._get_colparams(update_stmt) def create_param(p): if isinstance(p, sql.BindParamClause): self.binds[p.key] = p @@ -409,6 +418,59 @@ class ANSICompiler(sql.Compiled): self.strings[update_stmt] = text + + def _get_colparams(self, stmt): + """determines the VALUES or SET clause for an INSERT or UPDATE + clause based on the arguments specified to this ANSICompiler object + (i.e., the execute() or compile() method clause object): + + insert(mytable).execute(col1='foo', col2='bar') + mytable.update().execute(col2='foo', col3='bar') + + in the above examples, the insert() and update() methods have no "values" sent to them + at all, so compiling them with no arguments would yield an insert for all table columns, + or an update with no SET clauses. but the parameters sent indicate a set of per-compilation + arguments that result in a differently compiled INSERT or UPDATE object compared to the + original. The "values" parameter to the insert/update is figured as well if present, + but the incoming "parameters" sent here take precedence. + """ + # case one: no parameters in the statement, no parameters in the + # compiled params - just return binds for all the table columns + if self.parameters is None and stmt.parameters is None: + return [(c, bindparam(c.name, type=c.type)) for c in stmt.table.columns] + + # if we have statement parameters - set defaults in the + # compiled params + if self.parameters is None: + parameters = {} + else: + parameters = self.parameters.copy() + + if stmt.parameters is not None: + for k, v in stmt.parameters.iteritems(): + parameters.setdefault(k, v) + + # now go thru compiled params, get the Column object for each key + d = {} + for key, value in parameters.iteritems(): + if isinstance(key, schema.Column): + d[key] = value + else: + try: + d[stmt.table.columns[str(key)]] = value + except KeyError: + pass + + # create a list of column assignment clauses as tuples + values = [] + for c in stmt.table.columns: + if d.has_key(c): + value = d[c] + if sql._is_literal(value): + value = bindparam(c.name, value, type=c.type) + values.append((c, value)) + return values + def visit_delete(self, delete_stmt): text = "DELETE FROM " + delete_stmt.table.fullname diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 349bb4d1d6..1e5ba34d02 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -324,10 +324,11 @@ class SQLEngine(schema.SchemaEngine): pass def execute_compiled(self, compiled, parameters, connection=None, cursor=None, echo=None, **kwargs): - """executes the given string-based SQL statement with the given parameters. + """executes the given compiled statement object with the given parameters. - The parameters can be a dictionary or a list, or a list of dictionaries or lists, depending - on the paramstyle of the DBAPI. + The parameters can be a dictionary of key/value pairs, or a list of dictionaries for an + executemany() style of execution. Engines that use positional parameters will convert + the parameters to a list before execution. If the current thread has specified a transaction begin() for this engine, the statement will be executed in the context of the current transactional connection. @@ -360,6 +361,12 @@ class SQLEngine(schema.SchemaEngine): if cursor is None: cursor = connection.cursor() + executemany = parameters is not None and (isinstance(parameters, list) or isinstance(parameters, tuple)) + if executemany: + parameters = [compiled.get_params(**m) for m in parameters] + else: + parameters = compiled.get_params(**parameters) + def proxy(statement=None, parameters=None): if statement is None: return cursor @@ -371,7 +378,7 @@ class SQLEngine(schema.SchemaEngine): parameters = [p.values() for p in parameters] else: parameters = parameters.values() - + self.execute(statement, parameters, connection=connection, cursor=cursor) return cursor diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index a634767eaa..3a248f4347 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -15,7 +15,6 @@ # along with this library; if not, write to the Free Software # Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. - """defines the base components of SQL expression trees.""" import sqlalchemy.schema as schema @@ -270,10 +269,8 @@ class Compiled(ClauseVisitor): def execute(self, *multiparams, **params): """executes this compiled object using the underlying SQLEngine""" if len(multiparams): - params = [self.get_params(**m) for m in multiparams] - else: - params = self.get_params(**params) - + params = multiparams + return self.engine.execute_compiled(self, params) def scalar(self, *multiparams, **params): @@ -447,56 +444,50 @@ class CompareMixin(object): return BinaryClause(self, obj, operator) class FromClause(ClauseElement): - """represents a FROM clause element in a SQL statement.""" - + """represents an element within the FROM clause of a SELECT statement.""" def __init__(self, from_name = None, from_key = None): self.from_name = from_name self.id = from_key or from_name - def _get_from_objects(self): # this could also be [self], at the moment it doesnt matter to the Select object return [] - def hash_key(self): return "FromClause(%s, %s)" % (repr(self.id), repr(self.from_name)) - def accept_visitor(self, visitor): visitor.visit_fromclause(self) class BindParamClause(ClauseElement, CompareMixin): + """represents a bind parameter. public constructor is the bindparam() function.""" def __init__(self, key, value, shortname = None, type = None): self.key = key self.value = value self.shortname = shortname self.type = type or types.NULLTYPE - def accept_visitor(self, visitor): visitor.visit_bindparam(self) - def _get_from_objects(self): return [] - def hash_key(self): return "BindParam(%s, %s, %s)" % (repr(self.key), repr(self.value), repr(self.shortname)) - def typeprocess(self, value): return self.type.convert_bind_param(value) class TextClause(ClauseElement): - """represents literal text, including SQL fragments as well - as literal (non bind-param) values.""" + """represents literal a SQL text fragment. public constructor is the + text() function. - def __init__(self, text = "", engine=None, isliteral=False): + TextClauses, since they can be anything, have no comparison operators or + typing information. + + A single literal value within a compiled SQL statement is more useful + being specified as a bind parameter via the bindparam() method, + since it provides more information about what it is, including an optional + type, as well as providing comparison operations.""" + def __init__(self, text = "", engine=None): self.text = text self.parens = False self._engine = engine self.id = id(self) - if isliteral: - if isinstance(text, int) or isinstance(text, long): - self.text = str(text) - else: - text = re.sub(r"'", r"''", text) - self.text = "'" + text + "'" def accept_visitor(self, visitor): visitor.visit_textclause(self) def hash_key(self): @@ -505,6 +496,8 @@ class TextClause(ClauseElement): return [] class Null(ClauseElement): + """represents the NULL keyword in a SQL statement. public contstructor is the + null() function.""" def accept_visitor(self, visitor): visitor.visit_null(self) def _get_from_objects(self): @@ -856,8 +849,8 @@ class TableImpl(Selectable): self._rowid_column._set_parent(table) rowid_column = property(lambda s: s._rowid_column) - engine = property(lambda s: s.table.engine) + columns = property(lambda self: self.table.columns) def _get_col_by_original(self, column): try: @@ -880,35 +873,24 @@ class TableImpl(Selectable): def join(self, right, *args, **kwargs): return Join(self.table, right, *args, **kwargs) - def outerjoin(self, right, *args, **kwargs): return Join(self.table, right, isouter = True, *args, **kwargs) - def alias(self, name): return Alias(self.table, name) - def select(self, whereclause = None, **params): return select([self.table], whereclause, **params) - def insert(self, values = None): return insert(self.table, values=values) - def update(self, whereclause = None, values = None): return update(self.table, whereclause, values) - def delete(self, whereclause = None): return delete(self.table, whereclause) - - columns = property(lambda self: self.table.columns) - - def _get_from_objects(self): - return [self.table] - def create(self, **params): self.table.engine.create(self.table) - def drop(self, **params): self.table.engine.drop(self.table) + def _get_from_objects(self): + return [self.table] class SelectBaseMixin(object): """base class for Select and CompoundSelects""" @@ -1091,6 +1073,10 @@ class Select(SelectBaseMixin, Selectable): froms = property(lambda s: s._get_froms()) def accept_visitor(self, visitor): + # TODO: add contextual visit_ methods + # visit_select_whereclause, visit_select_froms, visit_select_orderby, etc. + # which will allow the compiler to set contextual flags before traversing + # into each thing. for f in self._get_froms(): f.accept_visitor(visitor) if self.whereclause is not None: @@ -1118,16 +1104,13 @@ class Select(SelectBaseMixin, Selectable): self._engine = e return e return None - - class UpdateBase(ClauseElement): - """forms the base for INSERT, UPDATE, and DELETE statements. - Deals with the special needs of INSERT and UPDATE parameter lists - - these statements have two separate lists of parameters, those - defined when the statement is constructed, and those specified at compile time.""" + """forms the base for INSERT, UPDATE, and DELETE statements.""" def _process_colparams(self, parameters): + """receives the "values" of an INSERT or UPDATE statement and constructs + appropriate ind parameters.""" if parameters is None: return None @@ -1154,57 +1137,6 @@ class UpdateBase(ClauseElement): del parameters[key] return parameters - def get_colparams(self, parameters): - """this is used by the ANSICompiler to determine the VALUES or SET clause based on the arguments - specified to the execute() or compile() method of the INSERT or UPDATE clause: - - insert(mytable).execute(col1='foo', col2='bar') - mytable.update().execute(col2='foo', col3='bar') - - in the above examples, the insert() and update() methods have no "values" sent to them - at all, so compiling them with no arguments would yield an insert for all table columns, - or an update with no SET clauses. but the parameters sent indicate a set of per-compilation - arguments that result in a differently compiled INSERT or UPDATE object compared to the - original. The "values" parameter to the insert/update is figured as well if present, - but the incoming "parameters" sent here take precedence. - """ - # case one: no parameters in the statement, no parameters in the - # compiled params - just return binds for all the table columns - if parameters is None and self.parameters is None: - return [(c, bindparam(c.name, type=c.type)) for c in self.table.columns] - - # if we have statement parameters - set defaults in the - # compiled params - if parameters is None: - parameters = {} - else: - parameters = parameters.copy() - - if self.parameters is not None: - for k, v in self.parameters.iteritems(): - parameters.setdefault(k, v) - - # now go thru compiled params, get the Column object for each key - d = {} - for key, value in parameters.iteritems(): - if isinstance(key, schema.Column): - d[key] = value - else: - try: - d[self.table.columns[str(key)]] = value - except KeyError: - pass - - # create a list of column assignment clauses as tuples - values = [] - for c in self.table.columns: - if d.has_key(c): - value = d[c] - if _is_literal(value): - value = bindparam(c.name, value, type=c.type) - values.append((c, value)) - return values - class Insert(UpdateBase): def __init__(self, table, values=None, **params): diff --git a/test/testbase.py b/test/testbase.py index ab32f803fe..d08b585277 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -76,15 +76,22 @@ class AssertMixin(PersistTest): callable_() finally: db.set_assert_list(None, None) + def assert_sql_count(self, db, callable_, count): + db.sql_count = 0 + try: + callable_() + finally: + self.assert_(db.sql_count == count, "desired statement count %d does not match %d" % (count, db.sql_count)) class EngineAssert(object): """decorates a SQLEngine object to match the incoming queries against a set of assertions.""" def __init__(self, engine): self.engine = engine - self.realexec = engine.execute_compiled - engine.execute_compiled = self.execute_compiled + self.realexec = engine.pre_exec + engine.pre_exec = self.pre_exec self.logger = engine.logger self.set_assert_list(None, None) + self.sql_count = 0 def __getattr__(self, key): return getattr(self.engine, key) def set_assert_list(self, unittest, list): @@ -92,15 +99,14 @@ class EngineAssert(object): self.assert_list = list if list is not None: self.assert_list.reverse() - def _set_echo(self, echo): self.engine.echo = echo echo = property(lambda s: s.engine.echo, _set_echo) - def execute_compiled(self, compiled, parameters, **kwargs): + def pre_exec(self, proxy, compiled, parameters, **kwargs): self.engine.logger = self.logger statement = str(compiled) statement = re.sub(r'\n', '', statement) - + if self.assert_list is not None: item = self.assert_list.pop() (query, params) = item @@ -127,7 +133,8 @@ class EngineAssert(object): query = re.sub(r':([\w_]+)', repl, query) self.unittest.assert_(statement == query and params == parameters, "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters))) - return self.realexec(compiled, parameters, **kwargs) + self.sql_count += 1 + return self.realexec(proxy, compiled, parameters, **kwargs) class TTestSuite(unittest.TestSuite): -- 2.47.2