self.strings[alias] = self.get_str(alias.selectable)
def visit_select(self, select):
- inner_columns = []
-
+
+ # the actual list of columns to print in the SELECT column list.
+ # its an ordered dictionary to insure that the actual labeled column name
+ # is unique.
+ inner_columns = OrderedDict()
+ def col_key(c):
+ if select.use_labels:
+ return c.label
+ else:
+ return self.get_str(c)
+
self.select_stack.append(select)
for c in select._raw_columns:
if c.is_selectable():
for co in c.columns:
co.accept_visitor(self)
- inner_columns.append(co)
+ inner_columns[col_key(co)] = co
else:
c.accept_visitor(self)
- inner_columns.append(c)
+ inner_columns[col_key(c)] = c
self.select_stack.pop(-1)
if select.use_labels:
- collist = string.join(["%s AS %s" % (self.get_str(c), c.label) for c in inner_columns], ', ')
+ collist = string.join(["%s AS %s" % (self.get_str(v), k) for k, v in inner_columns.iteritems()], ', ')
else:
- collist = string.join([self.get_str(c) for c in inner_columns], ', ')
+ collist = string.join([k for k in inner_columns.keys()], ', ')
text = "SELECT "
if select.distinct:
# matching those keys
if self.parameters is not None:
revisit = False
- for c in inner_columns:
+ for c in inner_columns.values():
if self.parameters.has_key(c.key) and not self.binds.has_key(c.key):
value = self.parameters[c.key]
elif self.parameters.has_key(c.label) and not self.binds.has_key(c.label):
c.default.accept_visitor(vis)
self.isinsert = True
- colparams = insert_stmt.get_colparams(self.parameters)
+ colparams = self._get_colparams(insert_stmt)
for c in colparams:
b = c[1]
self.binds[b.key] = b
self.strings[insert_stmt] = text
def visit_update(self, update_stmt):
- colparams = update_stmt.get_colparams(self.parameters)
+ colparams = self._get_colparams(update_stmt)
def create_param(p):
if isinstance(p, sql.BindParamClause):
self.binds[p.key] = p
self.strings[update_stmt] = text
+
+ def _get_colparams(self, stmt):
+ """determines the VALUES or SET clause for an INSERT or UPDATE
+ clause based on the arguments specified to this ANSICompiler object
+ (i.e., the execute() or compile() method clause object):
+
+ insert(mytable).execute(col1='foo', col2='bar')
+ mytable.update().execute(col2='foo', col3='bar')
+
+ in the above examples, the insert() and update() methods have no "values" sent to them
+ at all, so compiling them with no arguments would yield an insert for all table columns,
+ or an update with no SET clauses. but the parameters sent indicate a set of per-compilation
+ arguments that result in a differently compiled INSERT or UPDATE object compared to the
+ original. The "values" parameter to the insert/update is figured as well if present,
+ but the incoming "parameters" sent here take precedence.
+ """
+ # case one: no parameters in the statement, no parameters in the
+ # compiled params - just return binds for all the table columns
+ if self.parameters is None and stmt.parameters is None:
+ return [(c, bindparam(c.name, type=c.type)) for c in stmt.table.columns]
+
+ # if we have statement parameters - set defaults in the
+ # compiled params
+ if self.parameters is None:
+ parameters = {}
+ else:
+ parameters = self.parameters.copy()
+
+ if stmt.parameters is not None:
+ for k, v in stmt.parameters.iteritems():
+ parameters.setdefault(k, v)
+
+ # now go thru compiled params, get the Column object for each key
+ d = {}
+ for key, value in parameters.iteritems():
+ if isinstance(key, schema.Column):
+ d[key] = value
+ else:
+ try:
+ d[stmt.table.columns[str(key)]] = value
+ except KeyError:
+ pass
+
+ # create a list of column assignment clauses as tuples
+ values = []
+ for c in stmt.table.columns:
+ if d.has_key(c):
+ value = d[c]
+ if sql._is_literal(value):
+ value = bindparam(c.name, value, type=c.type)
+ values.append((c, value))
+ return values
+
def visit_delete(self, delete_stmt):
text = "DELETE FROM " + delete_stmt.table.fullname
# along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
-
"""defines the base components of SQL expression trees."""
import sqlalchemy.schema as schema
def execute(self, *multiparams, **params):
"""executes this compiled object using the underlying SQLEngine"""
if len(multiparams):
- params = [self.get_params(**m) for m in multiparams]
- else:
- params = self.get_params(**params)
-
+ params = multiparams
+
return self.engine.execute_compiled(self, params)
def scalar(self, *multiparams, **params):
return BinaryClause(self, obj, operator)
class FromClause(ClauseElement):
- """represents a FROM clause element in a SQL statement."""
-
+ """represents an element within the FROM clause of a SELECT statement."""
def __init__(self, from_name = None, from_key = None):
self.from_name = from_name
self.id = from_key or from_name
-
def _get_from_objects(self):
# this could also be [self], at the moment it doesnt matter to the Select object
return []
-
def hash_key(self):
return "FromClause(%s, %s)" % (repr(self.id), repr(self.from_name))
-
def accept_visitor(self, visitor):
visitor.visit_fromclause(self)
class BindParamClause(ClauseElement, CompareMixin):
+ """represents a bind parameter. public constructor is the bindparam() function."""
def __init__(self, key, value, shortname = None, type = None):
self.key = key
self.value = value
self.shortname = shortname
self.type = type or types.NULLTYPE
-
def accept_visitor(self, visitor):
visitor.visit_bindparam(self)
-
def _get_from_objects(self):
return []
-
def hash_key(self):
return "BindParam(%s, %s, %s)" % (repr(self.key), repr(self.value), repr(self.shortname))
-
def typeprocess(self, value):
return self.type.convert_bind_param(value)
class TextClause(ClauseElement):
- """represents literal text, including SQL fragments as well
- as literal (non bind-param) values."""
+ """represents literal a SQL text fragment. public constructor is the
+ text() function.
- def __init__(self, text = "", engine=None, isliteral=False):
+ TextClauses, since they can be anything, have no comparison operators or
+ typing information.
+
+ A single literal value within a compiled SQL statement is more useful
+ being specified as a bind parameter via the bindparam() method,
+ since it provides more information about what it is, including an optional
+ type, as well as providing comparison operations."""
+ def __init__(self, text = "", engine=None):
self.text = text
self.parens = False
self._engine = engine
self.id = id(self)
- if isliteral:
- if isinstance(text, int) or isinstance(text, long):
- self.text = str(text)
- else:
- text = re.sub(r"'", r"''", text)
- self.text = "'" + text + "'"
def accept_visitor(self, visitor):
visitor.visit_textclause(self)
def hash_key(self):
return []
class Null(ClauseElement):
+ """represents the NULL keyword in a SQL statement. public contstructor is the
+ null() function."""
def accept_visitor(self, visitor):
visitor.visit_null(self)
def _get_from_objects(self):
self._rowid_column._set_parent(table)
rowid_column = property(lambda s: s._rowid_column)
-
engine = property(lambda s: s.table.engine)
+ columns = property(lambda self: self.table.columns)
def _get_col_by_original(self, column):
try:
def join(self, right, *args, **kwargs):
return Join(self.table, right, *args, **kwargs)
-
def outerjoin(self, right, *args, **kwargs):
return Join(self.table, right, isouter = True, *args, **kwargs)
-
def alias(self, name):
return Alias(self.table, name)
-
def select(self, whereclause = None, **params):
return select([self.table], whereclause, **params)
-
def insert(self, values = None):
return insert(self.table, values=values)
-
def update(self, whereclause = None, values = None):
return update(self.table, whereclause, values)
-
def delete(self, whereclause = None):
return delete(self.table, whereclause)
-
- columns = property(lambda self: self.table.columns)
-
- def _get_from_objects(self):
- return [self.table]
-
def create(self, **params):
self.table.engine.create(self.table)
-
def drop(self, **params):
self.table.engine.drop(self.table)
+ def _get_from_objects(self):
+ return [self.table]
class SelectBaseMixin(object):
"""base class for Select and CompoundSelects"""
froms = property(lambda s: s._get_froms())
def accept_visitor(self, visitor):
+ # TODO: add contextual visit_ methods
+ # visit_select_whereclause, visit_select_froms, visit_select_orderby, etc.
+ # which will allow the compiler to set contextual flags before traversing
+ # into each thing.
for f in self._get_froms():
f.accept_visitor(visitor)
if self.whereclause is not None:
self._engine = e
return e
return None
-
-
class UpdateBase(ClauseElement):
- """forms the base for INSERT, UPDATE, and DELETE statements.
- Deals with the special needs of INSERT and UPDATE parameter lists -
- these statements have two separate lists of parameters, those
- defined when the statement is constructed, and those specified at compile time."""
+ """forms the base for INSERT, UPDATE, and DELETE statements."""
def _process_colparams(self, parameters):
+ """receives the "values" of an INSERT or UPDATE statement and constructs
+ appropriate ind parameters."""
if parameters is None:
return None
del parameters[key]
return parameters
- def get_colparams(self, parameters):
- """this is used by the ANSICompiler to determine the VALUES or SET clause based on the arguments
- specified to the execute() or compile() method of the INSERT or UPDATE clause:
-
- insert(mytable).execute(col1='foo', col2='bar')
- mytable.update().execute(col2='foo', col3='bar')
-
- in the above examples, the insert() and update() methods have no "values" sent to them
- at all, so compiling them with no arguments would yield an insert for all table columns,
- or an update with no SET clauses. but the parameters sent indicate a set of per-compilation
- arguments that result in a differently compiled INSERT or UPDATE object compared to the
- original. The "values" parameter to the insert/update is figured as well if present,
- but the incoming "parameters" sent here take precedence.
- """
- # case one: no parameters in the statement, no parameters in the
- # compiled params - just return binds for all the table columns
- if parameters is None and self.parameters is None:
- return [(c, bindparam(c.name, type=c.type)) for c in self.table.columns]
-
- # if we have statement parameters - set defaults in the
- # compiled params
- if parameters is None:
- parameters = {}
- else:
- parameters = parameters.copy()
-
- if self.parameters is not None:
- for k, v in self.parameters.iteritems():
- parameters.setdefault(k, v)
-
- # now go thru compiled params, get the Column object for each key
- d = {}
- for key, value in parameters.iteritems():
- if isinstance(key, schema.Column):
- d[key] = value
- else:
- try:
- d[self.table.columns[str(key)]] = value
- except KeyError:
- pass
-
- # create a list of column assignment clauses as tuples
- values = []
- for c in self.table.columns:
- if d.has_key(c):
- value = d[c]
- if _is_literal(value):
- value = bindparam(c.name, value, type=c.type)
- values.append((c, value))
- return values
-
class Insert(UpdateBase):
def __init__(self, table, values=None, **params):
callable_()
finally:
db.set_assert_list(None, None)
+ def assert_sql_count(self, db, callable_, count):
+ db.sql_count = 0
+ try:
+ callable_()
+ finally:
+ self.assert_(db.sql_count == count, "desired statement count %d does not match %d" % (count, db.sql_count))
class EngineAssert(object):
"""decorates a SQLEngine object to match the incoming queries against a set of assertions."""
def __init__(self, engine):
self.engine = engine
- self.realexec = engine.execute_compiled
- engine.execute_compiled = self.execute_compiled
+ self.realexec = engine.pre_exec
+ engine.pre_exec = self.pre_exec
self.logger = engine.logger
self.set_assert_list(None, None)
+ self.sql_count = 0
def __getattr__(self, key):
return getattr(self.engine, key)
def set_assert_list(self, unittest, list):
self.assert_list = list
if list is not None:
self.assert_list.reverse()
-
def _set_echo(self, echo):
self.engine.echo = echo
echo = property(lambda s: s.engine.echo, _set_echo)
- def execute_compiled(self, compiled, parameters, **kwargs):
+ def pre_exec(self, proxy, compiled, parameters, **kwargs):
self.engine.logger = self.logger
statement = str(compiled)
statement = re.sub(r'\n', '', statement)
-
+
if self.assert_list is not None:
item = self.assert_list.pop()
(query, params) = item
query = re.sub(r':([\w_]+)', repl, query)
self.unittest.assert_(statement == query and params == parameters, "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
- return self.realexec(compiled, parameters, **kwargs)
+ self.sql_count += 1
+ return self.realexec(proxy, compiled, parameters, **kwargs)
class TTestSuite(unittest.TestSuite):