]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
reworking concept of column lists, "FromObject", "Selectable";
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 30 Dec 2005 00:27:46 +0000 (00:27 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 30 Dec 2005 00:27:46 +0000 (00:27 +0000)
support for types to be propigated into boolean expressions;
new label() function/method to make any column/literal/function/bind param
into a "foo AS bar" clause, better support in ansisql for this concept;
trying to get column list on a select() object to be Column and ColumnClause
objects equally, working on mappers that map to those select() objects

lib/sqlalchemy/ansisql.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/mapping/mapper.py
lib/sqlalchemy/sql.py
test/mapper.py
test/types.py

index 7a90e746a093eaa78e0e0cfeee538882be082cba..8df5e535231f86f36f7abb09853aac9b3060f34b 100644 (file)
@@ -152,15 +152,19 @@ class ANSICompiler(sql.Compiled):
             return p
         else:
             return parameters
-            
+    
+    def visit_label(self, label):
+        if len(self.select_stack):
+            self.typemap.setdefault(label.name.lower(), label.obj.type)
+            if label.obj.type is None:
+                raise "nonetype" + repr(label.obj)
+        self.strings[label] = self.strings[label.obj] + " AS "  + label.name
+        
     def visit_column(self, column):
         if len(self.select_stack):
             # if we are within a visit to a Select, set up the "typemap"
             # for this column which is used to translate result set values
-            if self.select_stack[-1].use_labels:
-                self.typemap.setdefault(column.label.lower(), column.type)
-            else:
-                self.typemap.setdefault(column.key.lower(), column.type)
+            self.typemap.setdefault(column.key.lower(), column.type)
         if column.table.name is None:
             self.strings[column] = column.name
         else:
@@ -249,27 +253,24 @@ class ANSICompiler(sql.Compiled):
         # its an ordered dictionary to insure that the actual labeled column name
         # is unique.
         inner_columns = OrderedDict()
-        def col_key(c):
-            if select.use_labels:
-                return c.label
-            else:
-                return self.get_str(c)
-                
+
         self.select_stack.append(select)
         for c in select._raw_columns:
             if c.is_selectable():
                 for co in c.columns:
-                    co.accept_visitor(self)
-                    inner_columns[col_key(co)] = co
+                    if select.use_labels:
+                        l = co.label(co._label)
+                        l.accept_visitor(self)
+                        inner_columns[co._label] = l
+                    else:
+                        co.accept_visitor(self)
+                        inner_columns[self.get_str(co)] = co
             else:
                 c.accept_visitor(self)
-                inner_columns[col_key(c)] = c
+                inner_columns[self.get_str(c)] = c
         self.select_stack.pop(-1)
         
-        if select.use_labels:
-            collist = string.join(["%s AS %s" % (self.get_str(v), k) for k, v in inner_columns.iteritems()], ', ')
-        else:
-            collist = string.join([k for k in inner_columns.keys()], ', ')
+        collist = string.join([self.get_str(v) for v in inner_columns.values()], ', ')
 
         text = "SELECT "
         if select.distinct:
@@ -287,8 +288,6 @@ class ANSICompiler(sql.Compiled):
             for c in inner_columns.values():
                 if self.parameters.has_key(c.key) and not self.binds.has_key(c.key):
                     value = self.parameters[c.key]
-                elif self.parameters.has_key(c.label) and not self.binds.has_key(c.label):
-                    value = self.parameters[c.label]
                 else:
                     continue
                 clause = c==value
index 4c39200cbee57ce45fda012945eda93bc7ec7d29..24c67411d1c293ad64b6838b4de780c84dbc1bd6 100644 (file)
@@ -619,15 +619,17 @@ class ResultProxy:
                     rec = (typemap.get(colname, types.NULLTYPE), i)
                 else:
                     rec = (types.NULLTYPE, i)
+                if rec[0] is None:
+                    raise "None for metadata " + colname
                 if self.props.setdefault(colname, rec) is not rec:
                     self.props[colname] = (ResultProxy.AmbiguousColumn(colname), 0)
                 self.props[i] = rec
                 i+=1
 
     def _get_col(self, row, key):
-        if isinstance(key, schema.Column):
+        if isinstance(key, schema.Column) or isinstance(key, sql.ColumnElement):
             try:
-                rec = self.props[key.label.lower()]
+                rec = self.props[key._label.lower()]
             except KeyError:
                 try:
                     rec = self.props[key.key.lower()]
index 336a32e6f53b21458deb914b60215a9c329c081f..41616dceb83b8a94cc9332e10c6d835d2dbcd87b 100644 (file)
@@ -122,10 +122,10 @@ class Mapper(object):
         # load custom properties 
         if properties is not None:
             for key, prop in properties.iteritems():
-                if isinstance(prop, schema.Column):
+                if isinstance(prop, schema.Column) or isinstance(prop, sql.ColumnElement):
                     self.columns[key] = prop
                     prop = ColumnProperty(prop)
-                elif isinstance(prop, list) and isinstance(prop[0], schema.Column):
+                elif isinstance(prop, list) and (isinstance(prop[0], schema.Column) or isinstance(prop[0], sql.ColumnElement)) :
                     self.columns[key] = prop[0]
                     prop = ColumnProperty(*prop)
                 self.props[key] = prop
@@ -172,7 +172,7 @@ class Mapper(object):
 
     def add_property(self, key, prop):
         self.copyargs['properties'][key] = prop
-        if isinstance(prop, schema.Column):
+        if (isinstance(prop, schema.Column) or isinstance(prop, sql.ColumnElement)):
             self.columns[key] = prop
             prop = ColumnProperty(prop)
         self.props[key] = prop
@@ -581,6 +581,10 @@ class Mapper(object):
         if not no_sort:
             if self.order_by:
                 order_by = self.order_by
+#            elif self.table.rowid_column is not None:
+ #               order_by = self.table.rowid_column
+  #          else:
+  #              order_by = None
             else:
                 order_by = self.table.rowid_column
         else:
index 699175353bfbea1b9e564d8c9b6a04277bdb2f07..b0e86259a1e8cca578840c25519240ad150ee1e4 100644 (file)
@@ -162,6 +162,10 @@ def literal(value, type=None):
     """
     return BindParamClause('literal', value, type=type)
 
+def label(name, obj):
+    """returns a Label object for the given selectable, used in the column list for a select statement."""
+    return Label(name, obj)
+    
 def column(table, text):
     """returns a textual column clause, relative to a table.  this differs from using straight text
     or text() in that the column is treated like a regular column, i.e. gets added to a Selectable's list
@@ -224,7 +228,8 @@ class ClauseVisitor(schema.SchemaVisitor):
     def visit_null(self, null):pass
     def visit_clauselist(self, list):pass
     def visit_function(self, func):pass
-    
+    def visit_label(self, label):pass
+        
 class Compiled(ClauseVisitor):
     """represents a compiled SQL expression.  the __str__ method of the Compiled object
     should produce the actual text of the statement.  Compiled objects are specific to the
@@ -336,12 +341,6 @@ class ClauseElement(object):
             
     engine = property(lambda s: s._find_engine())
 
-    def _get_columns(self):
-        try:
-            return self._columns
-        except AttributeError:
-            return [self]
-    columns = property(lambda s: s._get_columns())
     
     def compile(self, engine = None, parameters = None, typemap=None):
         """compiles this SQL expression using its underlying SQLEngine to produce
@@ -419,6 +418,8 @@ class CompareMixin(object):
         return self._compare('LIKE', str(other) + "%")
     def endswith(self, other):
         return self._compare('LIKE', "%" + str(other))
+    def label(self, name):
+        return Label(name, self)
     # and here come the math operators:
     def __add__(self, other):
         return self._compare('+', other)
@@ -441,10 +442,47 @@ class CompareMixin(object):
             else:
                 obj = self._bind_param(obj)
 
-        return BinaryClause(self, obj, operator)
+        return BinaryClause(self, obj, operator, type=obj.type)
+
+class Selectable(ClauseElement):
+    """represents a column list-holding object."""
+
+    def _get_columns(self):
+        try:
+            return self._columns
+        except AttributeError:
+            return [self]
+    columns = property(lambda s: s._get_columns())
+    c = property(lambda self: self.columns)
+
+    def accept_visitor(self, visitor):
+        raise NotImplementedError()
+
+    def is_selectable(self):
+        return True
+
+    def select(self, whereclauses = None, **params):
+        return select([self], whereclauses, **params)
+
+    def _get_col_by_original(self, column):
+        """given a column which is a schema.Column object attached to a schema.Table object
+        (i.e. an "original" column), return the Column object from this 
+        Selectable which corresponds to that original Column, or None if this Selectable
+        does not contain the column."""
+        raise NotImplementedError()
 
-class FromClause(ClauseElement):
-    """represents an element within the FROM clause of a SELECT statement."""
+    def _group_parenthesized(self):
+        """indicates if this Selectable requires parenthesis when grouped into a compound
+        statement"""
+        return True
+
+class ColumnElement(Selectable, CompareMixin):
+    """represents a column element within the list of a Selectable's columns."""
+    primary_key = property(lambda s:False)
+    original = property(lambda self:self)
+
+class FromClause(Selectable):
+    """represents an element that can be used within the FROM clause of a SELECT statement."""
     def __init__(self, from_name = None, from_key = None):
         self.from_name = from_name
         self.id = from_key or from_name
@@ -455,6 +493,13 @@ class FromClause(ClauseElement):
         return "FromClause(%s, %s)" % (repr(self.id), repr(self.from_name))
     def accept_visitor(self, visitor): 
         visitor.visit_fromclause(self)
+    def join(self, right, *args, **kwargs):
+        return Join(self, right, *args, **kwargs)
+    def outerjoin(self, right, *args, **kwargs):
+        return Join(self, right, isouter = True, *args, **kwargs)
+    def alias(self, name):
+        return Alias(self, name)
+
     
 class BindParamClause(ClauseElement, CompareMixin):
     """represents a bind parameter.  public constructor is the bindparam() function."""
@@ -557,12 +602,12 @@ class CompoundClause(ClauseList):
     def hash_key(self):
         return string.join([c.hash_key() for c in self.clauses], self.operator or " ")
 
+    
 class Function(ClauseList, CompareMixin):
     """describes a SQL function. extends ClauseList to provide comparison operators."""
     def __init__(self, name, *clauses, **kwargs):
         self.name = name
-        self.type = kwargs.get('type', None)
-        self.label = kwargs.get('label', None)
+        self.type = kwargs.get('type', types.NULLTYPE)
         ClauseList.__init__(self, parens=True, *clauses)
     key = property(lambda self:self.label or self.name)
     def append(self, clause):
@@ -595,10 +640,11 @@ class Function(ClauseList, CompareMixin):
         
 class BinaryClause(ClauseElement, CompareMixin):
     """represents two clauses with an operator in between"""
-    def __init__(self, left, right, operator):
+    def __init__(self, left, right, operator, type=None):
         self.left = left
         self.right = right
         self.operator = operator
+        self.type = type
         self.parens = False
     def copy_container(self):
         return BinaryClause(self.left.copy_container(), self.right.copy_container(), self.operator)
@@ -614,42 +660,10 @@ class BinaryClause(ClauseElement, CompareMixin):
         c = self.left
         self.left = self.right
         self.right = c
-        
-class Selectable(FromClause):
-    """represents a column list-holding object, like a table, alias or subquery.  can be used anywhere a Table is used."""
-    
-    c = property(lambda self: self.columns)
-
-    def accept_visitor(self, visitor):
-        raise NotImplementedError()
-    
-    def is_selectable(self):
-        return True
-        
-    def select(self, whereclauses = None, **params):
-        return select([self], whereclauses, **params)
-
-    def _get_col_by_original(self, column):
-        """given a column which is a schema.Column object attached to a schema.Table object
-        (i.e. an "original" column), return the Column object from this 
-        Selectable which corresponds to that original Column, or None if this Selectable
-        does not contain the column."""
-        raise NotImplementedError()
-
-    def join(self, right, *args, **kwargs):
-        return Join(self, right, *args, **kwargs)
 
-    def outerjoin(self, right, *args, **kwargs):
-        return Join(self, right, isouter = True, *args, **kwargs)
 
-    def alias(self, name):
-        return Alias(self, name)
-    def _group_parenthesized(self):
-        """indicates if this Selectable requires parenthesis when grouped into a compound
-        statement"""
-        return True
         
-class Join(Selectable):
+class Join(FromClause):
     # TODO: put "using" + "natural" concepts in here and make "onclause" optional
     def __init__(self, left, right, onclause=None, isouter = False, allcols = True):
         self.left = left
@@ -731,7 +745,7 @@ class Join(Selectable):
     def _get_from_objects(self):
         return [self] + self.onclause._get_from_objects() + self.left._get_from_objects() + self.right._get_from_objects()
         
-class Alias(Selectable):
+class Alias(FromClause):
     def __init__(self, selectable, alias = None):
         self.selectable = selectable
         self._columns = util.OrderedProperties()
@@ -765,20 +779,39 @@ class Alias(Selectable):
         
     engine = property(lambda s: s.selectable.engine)
 
-class ColumnClause(Selectable, CompareMixin):
+    
+class Label(ColumnElement):
+    def __init__(self, name, obj):
+        self.name = name
+        while isinstance(obj, Label):
+            obj = obj.obj
+        self.obj = obj
+        obj.parens=True
+    key = property(lambda s: s.name)
+    _label = property(lambda s: s.name)
+    def accept_visitor(self, visitor):
+        self.obj.accept_visitor(visitor)
+        visitor.visit_label(self)
+    def _get_from_objects(self):
+        return self.obj._get_from_objects()
+    def _make_proxy(self, selectable, name = None):
+        cc = ColumnClause(self.name)
+        selectable.c[self.name] = cc
+        return cc
+     
+class ColumnClause(ColumnElement):
     """represents a textual column clause in a SQL statement. allows the creation
     of an additional ad-hoc column that is compiled against a particular table."""
 
     def __init__(self, text, selectable=None):
         self.text = text
         self.table = selectable
-        self._impl = ColumnImpl(self)
         self.type = types.NullTypeEngine()
 
     name = property(lambda self:self.text)
     key = property(lambda self:self.text)
-    label = property(lambda self:self.text)
-
+    _label = property(lambda self:self.text)
+    
     def accept_visitor(self, visitor): 
         visitor.visit_columnclause(self)
 
@@ -807,11 +840,10 @@ class ColumnClause(Selectable, CompareMixin):
     def _make_proxy(self, selectable, name = None):
         c = ColumnClause(self.text or name, selectable)
         selectable.columns[c.key] = c
-        c._impl = ColumnImpl(c)
         return c
 
-class ColumnImpl(Selectable, CompareMixin):
-    """Selectable implementation that gets attached to a schema.Column object."""
+class ColumnImpl(ColumnElement):
+    """gets attached to a schema.Column object."""
     
     def __init__(self, column):
         self.column = column
@@ -819,12 +851,15 @@ class ColumnImpl(Selectable, CompareMixin):
         self._columns = [self.column]
         
         if column.table.name:
-            self.label = column.table.name + "_" + self.column.name
+            self._label = column.table.name + "_" + self.column.name
         else:
-            self.label = self.column.name
+            self._label = self.column.name
 
     engine = property(lambda s: s.column.engine)
     
+    def label(self, name):
+        return Label(name, self.column)
+        
     def copy_container(self):
         return self.column
 
@@ -855,9 +890,9 @@ class ColumnImpl(Selectable, CompareMixin):
             else:
                 obj = self._bind_param(obj)
 
-        return BinaryClause(self.column, obj, operator)
+        return BinaryClause(self.column, obj, operator, type=self.column.type)
 
-class TableImpl(Selectable):
+class TableImpl(FromClause):
     """attached to a schema.Table to provide it with a Selectable interface
     as well as other functions
     """
@@ -943,7 +978,7 @@ class SelectBaseMixin(object):
         else:
             return [self]
             
-class CompoundSelect(SelectBaseMixin, Selectable):
+class CompoundSelect(SelectBaseMixin, FromClause):
     def __init__(self, keyword, *selects, **kwargs):
         self.id = "Compound(%d)" % id(self)
         self.keyword = keyword
@@ -976,7 +1011,7 @@ class CompoundSelect(SelectBaseMixin, Selectable):
         else:
             return None
        
-class Select(SelectBaseMixin, Selectable):
+class Select(SelectBaseMixin, FromClause):
     """represents a SELECT statement, with appendable clauses, as well as 
     the ability to execute itself and return a result set."""
     def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, engine = None, limit=None, offset=None):
@@ -1047,23 +1082,32 @@ class Select(SelectBaseMixin, Selectable):
 
         for f in column._get_from_objects():
             f.accept_visitor(self._correlator)
-            if self.rowid_column is None and hasattr(f, 'rowid_column'):
+            if self.rowid_column is None and hasattr(f, 'rowid_column') and f.rowid_column is not None:
                 self.rowid_column = f.rowid_column._make_proxy(self)
         column._process_from_dict(self._froms, False)
 
         if column.is_selectable():
             for co in column.columns:
                 if self.use_labels:
-                    co._make_proxy(self, name = co.label)
+                    co._make_proxy(self, name = co._label)
                 else:
                     co._make_proxy(self)
-
+            
     def _get_col_by_original(self, column):
         if self.use_labels:
-            return self.columns.get(column.label,None)
+            return self.columns.get(column._label,None)
         else:
             return self.columns.get(column.key,None)
 
+    def _pks(self):
+        ret = {}
+        for from_obj in self._get_froms():
+            for c in from_obj.c:
+                if c.primary_key:
+                    ret[c] = c
+        return ret.keys()
+    primary_key = property (lambda self: self._pks())
+
     def append_whereclause(self, whereclause):
         self._append_condition('whereclause', whereclause)
     def append_having(self, having):
@@ -1077,6 +1121,9 @@ class Select(SelectBaseMixin, Selectable):
             setattr(self, attribute, and_(getattr(self, attribute), condition))
         else:
             setattr(self, attribute, condition)
+
+    def hash_key(self):
+        return "Select(%d)" % (id(self))
     
     def clear_from(self, id):
         self.append_from(FromClause(from_name = None, from_key = id))
index a18ec467d79e476a660d183950238d96e86c4263..90d182b6a7338fc7be06d5aaecec80cef5682aae 100644 (file)
@@ -118,6 +118,14 @@ class MapperTest(MapperSuperTest):
 #        l = m.select(order_by=[])
 #        l = m.select(order_by=None)
         
+        
+    def testfunction(self):
+        s = select([users, (users.c.user_id * 2).label('concat'), func.count(users.c.user_id).label('count')], group_by=[c for c in users.c], use_labels=True)
+        m = mapper(User, s.alias('test'))
+        l = m.select()
+        print [repr(x.__dict__) for x in l]
+        
+        
     def testmultitable(self):
         usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id)
         m = mapper(User, usersaddresses, primarytable = users, primary_key=[users.c.user_id])
index 26e8dfd10fcaa9a0a3a8621e21399e12e98393ec..70cc1500fc54c6fc3645dfee54f1ce4d9a0cb68a 100644 (file)
@@ -36,6 +36,10 @@ class TypesTest(testbase.PersistTest):
         l = users.select().execute().fetchall()
         print repr(l)
         self.assert_(l == [(2, u'BIND_INjackBIND_OUT'), (3, u'BIND_INlalaBIND_OUT'), (4, u'BIND_INfredBIND_OUT')])
+
+        l = users.select(use_labels=True).execute().fetchall()
+        print repr(l)
+        self.assert_(l == [(2, u'BIND_INjackBIND_OUT'), (3, u'BIND_INlalaBIND_OUT'), (4, u'BIND_INfredBIND_OUT')])
      
         
 if __name__ == "__main__":