]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- rewrote and simplified the system used to "target" columns across
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 Nov 2007 00:59:19 +0000 (00:59 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 Nov 2007 00:59:19 +0000 (00:59 +0000)
selectable expressions.  On the SQL side this is represented by the
"corresponding_column()" method. This method is used heavily by the ORM
to "adapt" elements of an expression to similar, aliased expressions,
as well as to target result set columns originally bound to a
table or selectable to an aliased, "corresponding" expression.  The new
rewrite features completely consistent and accurate behavior.
- the "orig_set" and "distance" elements as well as all associated
fanfare are gone (hooray !)
- columns now have an optional "proxies" list which is a list of all
columns they are a "proxy" for; only CompoundSelect cols proxy more than one column
(just like before).  set operations are used to determine lineage.
- CompoundSelects (i.e. unions) only create one public-facing proxy column per
column name.  primary key collections come out with just one column per embedded
PK column.
- made the alias used by eager load limited subquery anonymous.

CHANGES
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/expression.py
test/orm/inheritance/polymorph2.py
test/sql/selectable.py

diff --git a/CHANGES b/CHANGES
index 0a1a65730f3a8e022044bf49ba60408ebe4f7405..95ff2a4108b8cfb9fcd1099a9c1ce4b0fbd967c3 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -29,6 +29,14 @@ CHANGES
 
   - func. objects can be pickled/unpickled [ticket:844]
 
+  - rewrote and simplified the system used to "target" columns across
+    selectable expressions.  On the SQL side this is represented by the
+    "corresponding_column()" method. This method is used heavily by the ORM
+    to "adapt" elements of an expression to similar, aliased expressions,
+    as well as to target result set columns originally bound to a 
+    table or selectable to an aliased, "corresponding" expression.  The new
+    rewrite features completely consistent and accurate behavior.
+    
 - orm
   - eager loading with LIMIT/OFFSET applied no longer adds the primary 
     table joined to a limited subquery of itself; the eager loads now
index ac0dc83ab61d8c809537f32aada6c10c57abd3a4..2029dd3bee2ec3e0d51f2921106e30fde2e429ad 100644 (file)
@@ -882,7 +882,7 @@ class Query(object):
             if order_by:
                 s2.append_order_by(*util.to_list(order_by))
             
-            s3 = s2.alias('primary_tbl_limited')
+            s3 = s2.alias()
                 
             self._primary_adapter = mapperutil.create_row_adapter(s3, self.table)
 
index 29a28e54b6f0c97cddb3c3d5c22706c4a320e3cd..5ddca718a86ce774233659a85be82ead879a5230 100644 (file)
@@ -449,7 +449,6 @@ class Column(SchemaItem, expression._ColumnClause):
         self.onupdate = kwargs.pop('onupdate', None)
         self.autoincrement = kwargs.pop('autoincrement', True)
         self.constraints = util.Set()
-        self.__originating_column = self
         self._foreign_keys = util.OrderedSet()
         if kwargs:
             raise exceptions.ArgumentError("Unknown arguments passed to Column: " + repr(kwargs.keys()))
@@ -554,9 +553,7 @@ class Column(SchemaItem, expression._ColumnClause):
         fk = [ForeignKey(f._colspec) for f in self.foreign_keys]
         c = Column(name or self.name, self.type, self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, _is_oid = self._is_oid, quote=self.quote, *fk)
         c.table = selectable
-        c.orig_set = self.orig_set
-        c.__originating_column = self.__originating_column
-        c._distance = self._distance + 1
+        c.proxies = [self]
         c._pre_existing_column = self._pre_existing_column
         if not c._is_oid:
             selectable.columns.add(c)
@@ -635,10 +632,8 @@ class ForeignKey(SchemaItem):
                 # locate the parent table this foreign key is attached to.
                 # we use the "original" column which our parent column represents
                 # (its a list of columns/other ColumnElements if the parent table is a UNION)
-                for c in self.parent.orig_set:
-                    if isinstance(c, Column):
-                        parenttable = c.table
-                        break
+                if isinstance(self.parent.base_column, Column):
+                    parenttable = self.parent.base_column.table
                 else:
                     raise exceptions.ArgumentError("Parent column '%s' does not descend from a table-attached Column" % str(self.parent))
                 m = re.match(r"^(.+?)(?:\.(.+?))?(?:\.(.+?))?$", self._colspec, re.UNICODE)
index 22c296e98bcee5cb3941fd14b82b4616ae8d86bb..51bd176c3783514c4f2ce05025c361f7743f8df4 100644 (file)
@@ -1392,36 +1392,32 @@ class ColumnElement(ClauseElement, _CompareMixin):
             return None
 
     foreign_key = property(_one_fkey)
-
-    def _get_orig_set(self):
-        try:
-            return self.__orig_set
-        except AttributeError:
-            self.__orig_set = util.Set([self])
-            return self.__orig_set
-
-    def _set_orig_set(self, s):
-        if len(s) == 0:
-            s.add(self)
-        self.__orig_set = s
-
-    orig_set = property(_get_orig_set, _set_orig_set,
-                        doc=\
-        """A Set containing TableClause-bound, non-proxied ColumnElements 
-        for which this ColumnElement is a proxy.  In all cases except 
-        for a column proxied from a Union (i.e. CompoundSelect), this 
-        set will be just one element.
-        """)
-
+    
+    def base_column(self):
+        if hasattr(self, '_base_column'):
+            return self._base_column
+        p = self
+        while hasattr(p, 'proxies'):
+            p = p.proxies[0]
+        self._base_column = p
+        return p
+    base_column = property(base_column)
+    
+    def proxy_set(self):
+        if hasattr(self, '_proxy_set'):
+            return self._proxy_set
+        s = util.Set([self])
+        if hasattr(self, 'proxies'):
+            for c in self.proxies:
+                s = s.union(c.proxy_set)
+        self._proxy_set = s
+        return s
+    proxy_set = property(proxy_set)
+    
     def shares_lineage(self, othercolumn):
         """Return True if the given ``ColumnElement`` has a common ancestor to this ``ColumnElement``.
         """
-
-        for c in self.orig_set:
-            if c in othercolumn.orig_set:
-                return True
-        else:
-            return False
+        return len(self.proxy_set.intersection(othercolumn.proxy_set)) > 0
 
     def _make_proxy(self, selectable, name=None):
         """Create a new ``ColumnElement`` representing this
@@ -1434,7 +1430,7 @@ class ColumnElement(ClauseElement, _CompareMixin):
 
         if name is not None:
             co = _ColumnClause(name, selectable)
-            co.orig_set = self.orig_set
+            co.proxies = [self]
             selectable.columns[name]= co
             return co
         else:
@@ -1569,14 +1565,6 @@ class FromClause(Selectable):
 
         return False
 
-    def _get_all_embedded_columns(self):
-        ret = []
-        class FindCols(visitors.ClauseVisitor):
-            def visit_column(self, col):
-                ret.append(col)
-        FindCols().traverse(self)
-        return ret
-
     def is_derived_from(self, fromclause):
         """Return True if this FromClause is 'derived' from the given FromClause.
 
@@ -1616,19 +1604,20 @@ class FromClause(Selectable):
           of this ``FromClause``.
         """
 
-        if self.c.contains_column(column):
-            return column
-
-        if require_embedded and column not in util.Set(self._get_all_embedded_columns()):
+        if require_embedded and column not in self._get_all_embedded_columns():
             if not raiseerr:
                 return None
             else:
-                raise exceptions.InvalidRequestError("Column instance '%s' is not directly present within selectable '%s'" % (str(column), column.table))
-        for c in column.orig_set:
-            try:
-                return self.original_columns[c]
-            except KeyError:
-                pass
+                raise exceptions.InvalidRequestError("Column instance '%s' is not directly present within selectable '%s'" % (str(column), column.table.description))
+
+        col, intersect = None, None
+        target_set = column.proxy_set
+        for c in self.c + [self.oid_column]:
+            i = c.proxy_set.intersection(target_set)
+            if i and (intersect is None or len(i) > len(intersect)):
+                col, intersect = c, i
+        if col:
+            return col
         else:
             if keys_ok:
                 try:
@@ -1638,18 +1627,33 @@ class FromClause(Selectable):
             if not raiseerr:
                 return None
             else:
-                raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), self.name))
+                raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), self.description))
 
+    def description(self):
+        return getattr(self, 'name', self.__class__.__name__ + " object")
+    description = property(description)
+    
     def _clone_from_clause(self):
         # delete all the "generated" collections of columns for a
         # newly cloned FromClause, so that they will be re-derived
         # from the item.  this is because FromClause subclasses, when
         # cloned, need to reestablish new "proxied" columns that are
         # linked to the new item
-        for attr in ('_columns', '_primary_key' '_foreign_keys', '_orig_cols', '_oid_column'):
+        for attr in ('_columns', '_primary_key' '_foreign_keys', '_oid_column', '_embedded_columns'):
             if hasattr(self, attr):
                 delattr(self, attr)
 
+    def _get_all_embedded_columns(self):
+        if hasattr(self, '_embedded_columns'):
+            return self._embedded_columns
+        ret = util.Set()
+        class FindCols(visitors.ClauseVisitor):
+            def visit_column(self, col):
+                ret.add(col)
+        FindCols().traverse(self)
+        self._embedded_columns = ret
+        return ret
+
     def _expr_attr_func(name):
         def attr(self):
             try:
@@ -1663,22 +1667,10 @@ class FromClause(Selectable):
     c = property(_expr_attr_func('_columns'))
     primary_key = property(_expr_attr_func('_primary_key'))
     foreign_keys = property(_expr_attr_func('_foreign_keys'))
-    original_columns = property(_expr_attr_func('_orig_cols'), doc=\
-        """A dictionary mapping an original Table-bound 
-        column to a proxied column in this FromClause.
-        """)
 
     def _export_columns(self, columns=None):
         """Initialize column collections.
 
-        The collections include the primary key, foreign keys, list of
-        all columns, as well as the *_orig_cols* collection which is a
-        dictionary used to match Table-bound columns to proxied
-        columns in this ``FromClause``.  The columns in each
-        collection are *proxied* from the columns returned by the
-        _exportable_columns method, where a *proxied* column maintains
-        most or all of the properties of its original column, except
-        its parent ``Selectable`` is this ``FromClause``.
         """
 
         if hasattr(self, '_columns') and columns is None:
@@ -1686,24 +1678,11 @@ class FromClause(Selectable):
         self._columns = ColumnCollection()
         self._primary_key = ColumnSet()
         self._foreign_keys = util.Set()
-        self._orig_cols = {}
 
         if columns is None:
             columns = self._flatten_exportable_columns()
         for co in columns:
             cp = self._proxy_column(co)
-            for ci in cp.orig_set:
-                cx = self._orig_cols.get(ci)
-                # TODO: the '=' thing here relates to the order of
-                # columns as they are placed in the "columns"
-                # collection of a CompositeSelect, illustrated in
-                # test/sql/selectable.SelectableTest.testunion make
-                # this relationship less brittle
-                if cx is None or cp._distance <= cx._distance:
-                    self._orig_cols[ci] = cp
-        if self.oid_column is not None:
-            for ci in self.oid_column.orig_set:
-                self._orig_cols[ci] = self.oid_column
 
     def _flatten_exportable_columns(self):
         """Return the list of ColumnElements represented within this FromClause's _exportable_columns"""
@@ -2058,7 +2037,6 @@ class _Cast(ColumnElement):
         self.type = sqltypes.to_instance(totype)
         self.clause = clause
         self.typeclause = _TypeClause(self.type)
-        self._distance = 0
 
     def _copy_internals(self, clone=_clone):
         self.clause = clone(self.clause)
@@ -2073,8 +2051,7 @@ class _Cast(ColumnElement):
     def _make_proxy(self, selectable, name=None):
         if name is not None:
             co = _ColumnClause(name, selectable, type_=self.type)
-            co._distance = self._distance + 1
-            co.orig_set = self.orig_set
+            co.proxies = [self]
             selectable.columns[name]= co
             return co
         else:
@@ -2251,6 +2228,10 @@ class Join(FromClause):
 
         self.__primary_key = ColumnSet([c for c in self._flatten_exportable_columns() if c.primary_key and c not in omit])
 
+    def description(self):
+        return "Join object on %s and %s" % (self.left.description, self.right.description)
+    description = property(description)
+    
     primary_key = property(lambda s:s.__primary_key)
 
     def self_group(self, against=None):
@@ -2294,14 +2275,14 @@ class Join(FromClause):
         if len(crit) == 0:
             raise exceptions.ArgumentError(
                 "Can't find any foreign key relationships "
-                "between '%s' and '%s'" % (primary.name, secondary.name))
+                "between '%s' and '%s'" % (primary.description, secondary.description))
         elif len(constraints) > 1:
             raise exceptions.ArgumentError(
                 "Can't determine join between '%s' and '%s'; "
                 "tables have more than one foreign key "
                 "constraint relationship between them. "
                 "Please specify the 'onclause' of this "
-                "join explicitly." % (primary.name, secondary.name))
+                "join explicitly." % (primary.description, secondary.description))
         elif len(crit) == 1:
             return (crit[0])
         else:
@@ -2456,7 +2437,6 @@ class _ColumnElementAdapter(ColumnElement):
     def __init__(self, elem):
         self.elem = elem
         self.type = getattr(elem, 'type', None)
-        self.orig_set = getattr(elem, 'orig_set', util.Set())
 
     key = property(lambda s: s.elem.key)
     _label = property(lambda s: s.elem._label)
@@ -2477,12 +2457,11 @@ class _ColumnElementAdapter(ColumnElement):
         return getattr(self.elem, attr)
 
     def __getstate__(self):
-        return {'elem':self.elem, 'type':self.type, 'orig_set':self.orig_set
+        return {'elem':self.elem, 'type':self.type} 
 
     def __setstate__(self, state):
         self.elem = state['elem']
         self.type = state['type']
-        self.orig_set = state['orig_set']
 
 class _Grouping(_ColumnElementAdapter):
     """Represent a grouping within a column expression"""
@@ -2527,14 +2506,15 @@ class _Label(ColumnElement):
         while isinstance(obj, _Label):
             obj = obj.obj
         self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon'))
-
         self.obj = obj.self_group(against=operators.as_)
         self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None))
 
     key = property(lambda s: s.name)
     _label = property(lambda s: s.name)
-    orig_set = property(lambda s:s.obj.orig_set)
-
+    proxies = property(lambda s:s.obj.proxies)
+    base_column = property(lambda s:s.obj.base_column)
+    proxy_set = property(lambda s:s.obj.proxy_set)
+    
     def expression_element(self):
         return self.obj
 
@@ -2589,7 +2569,6 @@ class _ColumnClause(ColumnElement):
         self.table = selectable
         self.type = sqltypes.to_instance(type_)
         self._is_oid = _is_oid
-        self._distance = 0
         self.__label = None
         self.is_literal = is_literal
 
@@ -2621,8 +2600,6 @@ class _ColumnClause(ColumnElement):
                 self.__label = self.name
         return self.__label
 
-    is_labeled = property(lambda self:self.name != list(self.orig_set)[0].name)
-
     _label = property(_get_label)
 
     def label(self, name):
@@ -2647,8 +2624,7 @@ class _ColumnClause(ColumnElement):
         # otherwise its considered to be a label
         is_literal = self.is_literal and (name is None or name == self.name)
         c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type_=self.type, is_literal=is_literal)
-        c.orig_set = self.orig_set
-        c._distance = self._distance + 1
+        c.proxies = [self]
         if not self._is_oid:
             selectable.columns[c.name] = c
         return c
@@ -2686,18 +2662,6 @@ class TableClause(FromClause):
         self.append_column(c)
         return c
 
-    def _orig_columns(self):
-        try:
-            return self._orig_cols
-        except AttributeError:
-            self._orig_cols= {}
-            for c in self.columns:
-                for ci in c.orig_set:
-                    self._orig_cols[ci] = c
-            return self._orig_cols
-
-    original_columns = property(_orig_columns)
-
     def get_children(self, column_collections=True, **kwargs):
         if column_collections:
             return [c for c in self.c]
@@ -2922,18 +2886,17 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
                 yield c
 
     def _proxy_column(self, column):
-        if self.use_labels:
-            col = column._make_proxy(self, name=column._label)
+        existing = self._col_map.get(column.name, None)
+        if existing is not None:
+            existing.proxies.append(column)
+            return existing
         else:
-            col = column._make_proxy(self)
-        try:
-            colset = self._col_map[col.name]
-        except KeyError:
-            colset = util.Set()
-            self._col_map[col.name] = colset
-        [colset.add(c) for c in col.orig_set]
-        col.orig_set = colset
-        return col
+            if self.use_labels:
+                col = column._make_proxy(self, name=column._label)
+            else:
+                col = column._make_proxy(self)
+            self._col_map[col.name] = col
+            return col
 
     def _copy_internals(self, clone=_clone):
         self._clone_from_clause()
index 6acb75dd9caeabe22f33b2e3569b3034726c6477..117bf7031f43690d6dab625da1cdb125dbb51ccd 100644 (file)
@@ -893,10 +893,6 @@ class CustomPKTest(ORMTest):
         d['t2'] = t1.join(t2)
         pjoin = polymorphic_union(d, None, 'pjoin')
         
-        #print pjoin.original.primary_key
-        #print pjoin.primary_key
-        assert len(pjoin.primary_key) == 2
-        
         mapper(T1, t1, polymorphic_on=t1.c.type, polymorphic_identity='t1', select_table=pjoin, primary_key=[pjoin.c.id])
         mapper(T2, t2, inherits=T1, polymorphic_identity='t2')
         print [str(c) for c in class_mapper(T1).primary_key]
@@ -932,10 +928,6 @@ class CustomPKTest(ORMTest):
         d['t2'] = t1.join(t2)
         pjoin = polymorphic_union(d, None, 'pjoin')
 
-        #print pjoin.original.primary_key
-        #print pjoin.primary_key
-        assert len(pjoin.primary_key) == 2
-
         mapper(T1, t1, polymorphic_on=t1.c.type, polymorphic_identity='t1', select_table=pjoin)
         mapper(T2, t2, inherits=T1, polymorphic_identity='t2')
         assert len(class_mapper(T1).primary_key) == 1
index 9fca3ee08ef09083ebdbbf7376d28a1dfc2744a3..aa04e2936c81abc13f7939fe5fb3333e57f3549b 100755 (executable)
@@ -185,8 +185,8 @@ class SelectableTest(AssertMixin):
         print j4
         print j4.corresponding_column(j2.c.aid)
         print j4.c.aid
-        # TODO: this is the assertion case which fails
-#        assert j4.corresponding_column(j2.c.aid) is j4.c.aid
+        assert j4.corresponding_column(j2.c.aid) is j4.c.aid
+        assert j4.corresponding_column(a.c.id) is j4.c.id
 
 class PrimaryKeyTest(AssertMixin):
     def test_join_pk_collapse_implicit(self):