From ee40adf5f8905190960f0639e9c5203d854794eb Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 30 Dec 2005 00:27:46 +0000 Subject: [PATCH] reworking concept of column lists, "FromObject", "Selectable"; support for types to be propigated into boolean expressions; new label() function/method to make any column/literal/function/bind param into a "foo AS bar" clause, better support in ansisql for this concept; trying to get column list on a select() object to be Column and ColumnClause objects equally, working on mappers that map to those select() objects --- lib/sqlalchemy/ansisql.py | 39 ++++--- lib/sqlalchemy/engine.py | 6 +- lib/sqlalchemy/mapping/mapper.py | 10 +- lib/sqlalchemy/sql.py | 175 ++++++++++++++++++++----------- test/mapper.py | 8 ++ test/types.py | 4 + 6 files changed, 153 insertions(+), 89 deletions(-) diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 7a90e746a0..8df5e53523 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -152,15 +152,19 @@ class ANSICompiler(sql.Compiled): return p else: return parameters - + + def visit_label(self, label): + if len(self.select_stack): + self.typemap.setdefault(label.name.lower(), label.obj.type) + if label.obj.type is None: + raise "nonetype" + repr(label.obj) + self.strings[label] = self.strings[label.obj] + " AS " + label.name + def visit_column(self, column): if len(self.select_stack): # if we are within a visit to a Select, set up the "typemap" # for this column which is used to translate result set values - if self.select_stack[-1].use_labels: - self.typemap.setdefault(column.label.lower(), column.type) - else: - self.typemap.setdefault(column.key.lower(), column.type) + self.typemap.setdefault(column.key.lower(), column.type) if column.table.name is None: self.strings[column] = column.name else: @@ -249,27 +253,24 @@ class ANSICompiler(sql.Compiled): # its an ordered dictionary to insure that the actual labeled column name # is unique. inner_columns = OrderedDict() - def col_key(c): - if select.use_labels: - return c.label - else: - return self.get_str(c) - + self.select_stack.append(select) for c in select._raw_columns: if c.is_selectable(): for co in c.columns: - co.accept_visitor(self) - inner_columns[col_key(co)] = co + if select.use_labels: + l = co.label(co._label) + l.accept_visitor(self) + inner_columns[co._label] = l + else: + co.accept_visitor(self) + inner_columns[self.get_str(co)] = co else: c.accept_visitor(self) - inner_columns[col_key(c)] = c + inner_columns[self.get_str(c)] = c self.select_stack.pop(-1) - if select.use_labels: - collist = string.join(["%s AS %s" % (self.get_str(v), k) for k, v in inner_columns.iteritems()], ', ') - else: - collist = string.join([k for k in inner_columns.keys()], ', ') + collist = string.join([self.get_str(v) for v in inner_columns.values()], ', ') text = "SELECT " if select.distinct: @@ -287,8 +288,6 @@ class ANSICompiler(sql.Compiled): for c in inner_columns.values(): if self.parameters.has_key(c.key) and not self.binds.has_key(c.key): value = self.parameters[c.key] - elif self.parameters.has_key(c.label) and not self.binds.has_key(c.label): - value = self.parameters[c.label] else: continue clause = c==value diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 4c39200cbe..24c67411d1 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -619,15 +619,17 @@ class ResultProxy: rec = (typemap.get(colname, types.NULLTYPE), i) else: rec = (types.NULLTYPE, i) + if rec[0] is None: + raise "None for metadata " + colname if self.props.setdefault(colname, rec) is not rec: self.props[colname] = (ResultProxy.AmbiguousColumn(colname), 0) self.props[i] = rec i+=1 def _get_col(self, row, key): - if isinstance(key, schema.Column): + if isinstance(key, schema.Column) or isinstance(key, sql.ColumnElement): try: - rec = self.props[key.label.lower()] + rec = self.props[key._label.lower()] except KeyError: try: rec = self.props[key.key.lower()] diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py index 336a32e6f5..41616dceb8 100644 --- a/lib/sqlalchemy/mapping/mapper.py +++ b/lib/sqlalchemy/mapping/mapper.py @@ -122,10 +122,10 @@ class Mapper(object): # load custom properties if properties is not None: for key, prop in properties.iteritems(): - if isinstance(prop, schema.Column): + if isinstance(prop, schema.Column) or isinstance(prop, sql.ColumnElement): self.columns[key] = prop prop = ColumnProperty(prop) - elif isinstance(prop, list) and isinstance(prop[0], schema.Column): + elif isinstance(prop, list) and (isinstance(prop[0], schema.Column) or isinstance(prop[0], sql.ColumnElement)) : self.columns[key] = prop[0] prop = ColumnProperty(*prop) self.props[key] = prop @@ -172,7 +172,7 @@ class Mapper(object): def add_property(self, key, prop): self.copyargs['properties'][key] = prop - if isinstance(prop, schema.Column): + if (isinstance(prop, schema.Column) or isinstance(prop, sql.ColumnElement)): self.columns[key] = prop prop = ColumnProperty(prop) self.props[key] = prop @@ -581,6 +581,10 @@ class Mapper(object): if not no_sort: if self.order_by: order_by = self.order_by +# elif self.table.rowid_column is not None: + # order_by = self.table.rowid_column + # else: + # order_by = None else: order_by = self.table.rowid_column else: diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 699175353b..b0e86259a1 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -162,6 +162,10 @@ def literal(value, type=None): """ return BindParamClause('literal', value, type=type) +def label(name, obj): + """returns a Label object for the given selectable, used in the column list for a select statement.""" + return Label(name, obj) + def column(table, text): """returns a textual column clause, relative to a table. this differs from using straight text or text() in that the column is treated like a regular column, i.e. gets added to a Selectable's list @@ -224,7 +228,8 @@ class ClauseVisitor(schema.SchemaVisitor): def visit_null(self, null):pass def visit_clauselist(self, list):pass def visit_function(self, func):pass - + def visit_label(self, label):pass + class Compiled(ClauseVisitor): """represents a compiled SQL expression. the __str__ method of the Compiled object should produce the actual text of the statement. Compiled objects are specific to the @@ -336,12 +341,6 @@ class ClauseElement(object): engine = property(lambda s: s._find_engine()) - def _get_columns(self): - try: - return self._columns - except AttributeError: - return [self] - columns = property(lambda s: s._get_columns()) def compile(self, engine = None, parameters = None, typemap=None): """compiles this SQL expression using its underlying SQLEngine to produce @@ -419,6 +418,8 @@ class CompareMixin(object): return self._compare('LIKE', str(other) + "%") def endswith(self, other): return self._compare('LIKE', "%" + str(other)) + def label(self, name): + return Label(name, self) # and here come the math operators: def __add__(self, other): return self._compare('+', other) @@ -441,10 +442,47 @@ class CompareMixin(object): else: obj = self._bind_param(obj) - return BinaryClause(self, obj, operator) + return BinaryClause(self, obj, operator, type=obj.type) + +class Selectable(ClauseElement): + """represents a column list-holding object.""" + + def _get_columns(self): + try: + return self._columns + except AttributeError: + return [self] + columns = property(lambda s: s._get_columns()) + c = property(lambda self: self.columns) + + def accept_visitor(self, visitor): + raise NotImplementedError() + + def is_selectable(self): + return True + + def select(self, whereclauses = None, **params): + return select([self], whereclauses, **params) + + def _get_col_by_original(self, column): + """given a column which is a schema.Column object attached to a schema.Table object + (i.e. an "original" column), return the Column object from this + Selectable which corresponds to that original Column, or None if this Selectable + does not contain the column.""" + raise NotImplementedError() -class FromClause(ClauseElement): - """represents an element within the FROM clause of a SELECT statement.""" + def _group_parenthesized(self): + """indicates if this Selectable requires parenthesis when grouped into a compound + statement""" + return True + +class ColumnElement(Selectable, CompareMixin): + """represents a column element within the list of a Selectable's columns.""" + primary_key = property(lambda s:False) + original = property(lambda self:self) + +class FromClause(Selectable): + """represents an element that can be used within the FROM clause of a SELECT statement.""" def __init__(self, from_name = None, from_key = None): self.from_name = from_name self.id = from_key or from_name @@ -455,6 +493,13 @@ class FromClause(ClauseElement): return "FromClause(%s, %s)" % (repr(self.id), repr(self.from_name)) def accept_visitor(self, visitor): visitor.visit_fromclause(self) + def join(self, right, *args, **kwargs): + return Join(self, right, *args, **kwargs) + def outerjoin(self, right, *args, **kwargs): + return Join(self, right, isouter = True, *args, **kwargs) + def alias(self, name): + return Alias(self, name) + class BindParamClause(ClauseElement, CompareMixin): """represents a bind parameter. public constructor is the bindparam() function.""" @@ -557,12 +602,12 @@ class CompoundClause(ClauseList): def hash_key(self): return string.join([c.hash_key() for c in self.clauses], self.operator or " ") + class Function(ClauseList, CompareMixin): """describes a SQL function. extends ClauseList to provide comparison operators.""" def __init__(self, name, *clauses, **kwargs): self.name = name - self.type = kwargs.get('type', None) - self.label = kwargs.get('label', None) + self.type = kwargs.get('type', types.NULLTYPE) ClauseList.__init__(self, parens=True, *clauses) key = property(lambda self:self.label or self.name) def append(self, clause): @@ -595,10 +640,11 @@ class Function(ClauseList, CompareMixin): class BinaryClause(ClauseElement, CompareMixin): """represents two clauses with an operator in between""" - def __init__(self, left, right, operator): + def __init__(self, left, right, operator, type=None): self.left = left self.right = right self.operator = operator + self.type = type self.parens = False def copy_container(self): return BinaryClause(self.left.copy_container(), self.right.copy_container(), self.operator) @@ -614,42 +660,10 @@ class BinaryClause(ClauseElement, CompareMixin): c = self.left self.left = self.right self.right = c - -class Selectable(FromClause): - """represents a column list-holding object, like a table, alias or subquery. can be used anywhere a Table is used.""" - - c = property(lambda self: self.columns) - - def accept_visitor(self, visitor): - raise NotImplementedError() - - def is_selectable(self): - return True - - def select(self, whereclauses = None, **params): - return select([self], whereclauses, **params) - - def _get_col_by_original(self, column): - """given a column which is a schema.Column object attached to a schema.Table object - (i.e. an "original" column), return the Column object from this - Selectable which corresponds to that original Column, or None if this Selectable - does not contain the column.""" - raise NotImplementedError() - - def join(self, right, *args, **kwargs): - return Join(self, right, *args, **kwargs) - def outerjoin(self, right, *args, **kwargs): - return Join(self, right, isouter = True, *args, **kwargs) - def alias(self, name): - return Alias(self, name) - def _group_parenthesized(self): - """indicates if this Selectable requires parenthesis when grouped into a compound - statement""" - return True -class Join(Selectable): +class Join(FromClause): # TODO: put "using" + "natural" concepts in here and make "onclause" optional def __init__(self, left, right, onclause=None, isouter = False, allcols = True): self.left = left @@ -731,7 +745,7 @@ class Join(Selectable): def _get_from_objects(self): return [self] + self.onclause._get_from_objects() + self.left._get_from_objects() + self.right._get_from_objects() -class Alias(Selectable): +class Alias(FromClause): def __init__(self, selectable, alias = None): self.selectable = selectable self._columns = util.OrderedProperties() @@ -765,20 +779,39 @@ class Alias(Selectable): engine = property(lambda s: s.selectable.engine) -class ColumnClause(Selectable, CompareMixin): + +class Label(ColumnElement): + def __init__(self, name, obj): + self.name = name + while isinstance(obj, Label): + obj = obj.obj + self.obj = obj + obj.parens=True + key = property(lambda s: s.name) + _label = property(lambda s: s.name) + def accept_visitor(self, visitor): + self.obj.accept_visitor(visitor) + visitor.visit_label(self) + def _get_from_objects(self): + return self.obj._get_from_objects() + def _make_proxy(self, selectable, name = None): + cc = ColumnClause(self.name) + selectable.c[self.name] = cc + return cc + +class ColumnClause(ColumnElement): """represents a textual column clause in a SQL statement. allows the creation of an additional ad-hoc column that is compiled against a particular table.""" def __init__(self, text, selectable=None): self.text = text self.table = selectable - self._impl = ColumnImpl(self) self.type = types.NullTypeEngine() name = property(lambda self:self.text) key = property(lambda self:self.text) - label = property(lambda self:self.text) - + _label = property(lambda self:self.text) + def accept_visitor(self, visitor): visitor.visit_columnclause(self) @@ -807,11 +840,10 @@ class ColumnClause(Selectable, CompareMixin): def _make_proxy(self, selectable, name = None): c = ColumnClause(self.text or name, selectable) selectable.columns[c.key] = c - c._impl = ColumnImpl(c) return c -class ColumnImpl(Selectable, CompareMixin): - """Selectable implementation that gets attached to a schema.Column object.""" +class ColumnImpl(ColumnElement): + """gets attached to a schema.Column object.""" def __init__(self, column): self.column = column @@ -819,12 +851,15 @@ class ColumnImpl(Selectable, CompareMixin): self._columns = [self.column] if column.table.name: - self.label = column.table.name + "_" + self.column.name + self._label = column.table.name + "_" + self.column.name else: - self.label = self.column.name + self._label = self.column.name engine = property(lambda s: s.column.engine) + def label(self, name): + return Label(name, self.column) + def copy_container(self): return self.column @@ -855,9 +890,9 @@ class ColumnImpl(Selectable, CompareMixin): else: obj = self._bind_param(obj) - return BinaryClause(self.column, obj, operator) + return BinaryClause(self.column, obj, operator, type=self.column.type) -class TableImpl(Selectable): +class TableImpl(FromClause): """attached to a schema.Table to provide it with a Selectable interface as well as other functions """ @@ -943,7 +978,7 @@ class SelectBaseMixin(object): else: return [self] -class CompoundSelect(SelectBaseMixin, Selectable): +class CompoundSelect(SelectBaseMixin, FromClause): def __init__(self, keyword, *selects, **kwargs): self.id = "Compound(%d)" % id(self) self.keyword = keyword @@ -976,7 +1011,7 @@ class CompoundSelect(SelectBaseMixin, Selectable): else: return None -class Select(SelectBaseMixin, Selectable): +class Select(SelectBaseMixin, FromClause): """represents a SELECT statement, with appendable clauses, as well as the ability to execute itself and return a result set.""" def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, engine = None, limit=None, offset=None): @@ -1047,23 +1082,32 @@ class Select(SelectBaseMixin, Selectable): for f in column._get_from_objects(): f.accept_visitor(self._correlator) - if self.rowid_column is None and hasattr(f, 'rowid_column'): + if self.rowid_column is None and hasattr(f, 'rowid_column') and f.rowid_column is not None: self.rowid_column = f.rowid_column._make_proxy(self) column._process_from_dict(self._froms, False) if column.is_selectable(): for co in column.columns: if self.use_labels: - co._make_proxy(self, name = co.label) + co._make_proxy(self, name = co._label) else: co._make_proxy(self) - + def _get_col_by_original(self, column): if self.use_labels: - return self.columns.get(column.label,None) + return self.columns.get(column._label,None) else: return self.columns.get(column.key,None) + def _pks(self): + ret = {} + for from_obj in self._get_froms(): + for c in from_obj.c: + if c.primary_key: + ret[c] = c + return ret.keys() + primary_key = property (lambda self: self._pks()) + def append_whereclause(self, whereclause): self._append_condition('whereclause', whereclause) def append_having(self, having): @@ -1077,6 +1121,9 @@ class Select(SelectBaseMixin, Selectable): setattr(self, attribute, and_(getattr(self, attribute), condition)) else: setattr(self, attribute, condition) + + def hash_key(self): + return "Select(%d)" % (id(self)) def clear_from(self, id): self.append_from(FromClause(from_name = None, from_key = id)) diff --git a/test/mapper.py b/test/mapper.py index a18ec467d7..90d182b6a7 100644 --- a/test/mapper.py +++ b/test/mapper.py @@ -118,6 +118,14 @@ class MapperTest(MapperSuperTest): # l = m.select(order_by=[]) # l = m.select(order_by=None) + + def testfunction(self): + s = select([users, (users.c.user_id * 2).label('concat'), func.count(users.c.user_id).label('count')], group_by=[c for c in users.c], use_labels=True) + m = mapper(User, s.alias('test')) + l = m.select() + print [repr(x.__dict__) for x in l] + + def testmultitable(self): usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id) m = mapper(User, usersaddresses, primarytable = users, primary_key=[users.c.user_id]) diff --git a/test/types.py b/test/types.py index 26e8dfd10f..70cc1500fc 100644 --- a/test/types.py +++ b/test/types.py @@ -36,6 +36,10 @@ class TypesTest(testbase.PersistTest): l = users.select().execute().fetchall() print repr(l) self.assert_(l == [(2, u'BIND_INjackBIND_OUT'), (3, u'BIND_INlalaBIND_OUT'), (4, u'BIND_INfredBIND_OUT')]) + + l = users.select(use_labels=True).execute().fetchall() + print repr(l) + self.assert_(l == [(2, u'BIND_INjackBIND_OUT'), (3, u'BIND_INlalaBIND_OUT'), (4, u'BIND_INfredBIND_OUT')]) if __name__ == "__main__": -- 2.47.2