]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- rework ColumnAdapter and ORMAdapter to only provide the features
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 7 Sep 2014 18:55:44 +0000 (14:55 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 7 Sep 2014 18:55:44 +0000 (14:55 -0400)
we're now using; rework them fully so that their behavioral contract
is consistent regarding adapter.traverse() vs. adapter.columns[],
add a full suite of tests including advanced wrapping scenarios
previously only covered by test/orm/test_froms.py and
test/orm/inheritance/test_relationships.py
- identify several cases where label._order_by_label_clause would be
corrupted, e.g. due to adaption or annotation separately
- add full tests for #3148

lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/util.py
test/orm/test_query.py
test/sql/test_generative.py
test/sql/test_selectable.py

index ed2011d4e2195ae68946aa206e179f4b116fa1a7..1bb6b571e993bc744bfe83286666ac11c39f1997 100644 (file)
@@ -270,10 +270,8 @@ first()
 
 
 class ORMAdapter(sql_util.ColumnAdapter):
-    """Extends ColumnAdapter to accept ORM entities.
-
-    The selectable is extracted from the given entity,
-    and the AliasedClass if any is referenced.
+    """ColumnAdapter subclass which excludes adaptation of entities from
+    non-matching mappers.
 
     """
 
@@ -289,18 +287,18 @@ class ORMAdapter(sql_util.ColumnAdapter):
             self.aliased_class = entity
         else:
             self.aliased_class = None
+
         sql_util.ColumnAdapter.__init__(
             self, selectable, equivalents, chain_to,
             adapt_required=adapt_required,
             allow_label_resolve=allow_label_resolve,
-            anonymize_labels=anonymize_labels)
+            anonymize_labels=anonymize_labels,
+            include_fn=self._include_fn
+        )
 
-    def replace(self, elem):
+    def _include_fn(self, elem):
         entity = elem._annotations.get('parentmapper', None)
-        if not entity or entity.isa(self.mapper):
-            return sql_util.ColumnAdapter.replace(self, elem)
-        else:
-            return None
+        return not entity or entity.isa(self.mapper)
 
 
 class AliasedClass(object):
index ece6bce9e965bf608f66a975deca3897a85d8626..cf8de936df39ab731d9aa78f9fd42bc676cf2534 100644 (file)
@@ -2588,7 +2588,7 @@ class UnaryExpression(ColumnElement):
         return UnaryExpression(
             expr, operator=operators.distinct_op, type_=expr.type)
 
-    @util.memoized_property
+    @property
     def _order_by_label_element(self):
         if self.modifier in (operators.desc_op, operators.asc_op):
             return self.element._order_by_label_element
@@ -2913,7 +2913,7 @@ class Label(ColumnElement):
     def _allow_label_resolve(self):
         return self.element._allow_label_resolve
 
-    @util.memoized_property
+    @property
     def _order_by_label_element(self):
         return self
 
@@ -2949,6 +2949,7 @@ class Label(ColumnElement):
 
     def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw):
         self.element = clone(self.element, **kw)
+        self.__dict__.pop('_allow_label_resolve', None)
         if anonymize_labels:
             self.name = _anonymous_label(
                 '%%(%d %s)s' % (
index f630f9e935a3eb271e2672e90c32cf46fcadef96..d6f3b591594ab5248bfa59f4ec78d6fc457b1d6e 100644 (file)
@@ -428,35 +428,6 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None,
     return pairs
 
 
-class AliasedRow(object):
-    """Wrap a RowProxy with a translation map.
-
-    This object allows a set of keys to be translated
-    to those present in a RowProxy.
-
-    """
-
-    def __init__(self, row, map):
-        # AliasedRow objects don't nest, so un-nest
-        # if another AliasedRow was passed
-        if isinstance(row, AliasedRow):
-            self.row = row.row
-        else:
-            self.row = row
-        self.map = map
-
-    def __contains__(self, key):
-        return self.map[key] in self.row
-
-    def has_key(self, key):
-        return key in self
-
-    def __getitem__(self, key):
-        return self.row[self.map[key]]
-
-    def keys(self):
-        return self.row.keys()
-
 
 class ClauseAdapter(visitors.ReplacingCloningVisitor):
     """Clones and modifies clauses based on column correspondence.
@@ -486,23 +457,14 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
     """
 
     def __init__(self, selectable, equivalents=None,
-                 include=None, exclude=None,
                  include_fn=None, exclude_fn=None,
                  adapt_on_names=False, anonymize_labels=False):
         self.__traverse_options__ = {
             'stop_on': [selectable],
             'anonymize_labels': anonymize_labels}
         self.selectable = selectable
-        if include:
-            assert not include_fn
-            self.include_fn = lambda e: e in include
-        else:
-            self.include_fn = include_fn
-        if exclude:
-            assert not exclude_fn
-            self.exclude_fn = lambda e: e in exclude
-        else:
-            self.exclude_fn = exclude_fn
+        self.include_fn = include_fn
+        self.exclude_fn = exclude_fn
         self.equivalents = util.column_dict(equivalents or {})
         self.adapt_on_names = adapt_on_names
 
@@ -522,10 +484,8 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
             newcol = self.selectable.c.get(col.name)
         return newcol
 
-    magic_flag = False
-
     def replace(self, col):
-        if not self.magic_flag and isinstance(col, FromClause) and \
+        if isinstance(col, FromClause) and \
                 self.selectable.is_derived_from(col):
             return self.selectable
         elif not isinstance(col, ColumnElement):
@@ -541,72 +501,102 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
 class ColumnAdapter(ClauseAdapter):
     """Extends ClauseAdapter with extra utility functions.
 
-    Provides the ability to "wrap" this ClauseAdapter
-    around another, a columns dictionary which returns
-    adapted elements given an original, and an
-    adapted_row() factory.
+    Key aspects of ColumnAdapter include:
+
+    * Expressions that are adapted are stored in a persistent
+      .columns collection; so that an expression E adapted into
+      an expression E1, will return the same object E1 when adapted
+      a second time.   This is important in particular for things like
+      Label objects that are anonymized, so that the ColumnAdapter can
+      be used to present a consistent "adapted" view of things.
+
+    * Exclusion of items from the persistent collection based on
+      include/exclude rules, but also independent of hash identity.
+      This because "annotated" items all have the same hash identity as their
+      parent.
+
+    * "wrapping" capability is added, so that the replacement of an expression
+      E can proceed through a series of adapters.  This differs from the
+      visitor's "chaining" feature in that the resulting object is passed
+      through all replacing functions unconditionally, rather than stopping
+      at the first one that returns non-None.
+
+    * An adapt_required option, used by eager loading to indicate that
+      We don't trust a result row column that is not translated.
+      This is to prevent a column from being interpreted as that
+      of the child row in a self-referential scenario, see
+      inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency
 
     """
 
     def __init__(self, selectable, equivalents=None,
-                 chain_to=None, include=None,
-                 exclude=None, adapt_required=False,
+                 chain_to=None, adapt_required=False,
+                 include_fn=None, exclude_fn=None,
                  adapt_on_names=False,
                  allow_label_resolve=True,
                  anonymize_labels=False):
         ClauseAdapter.__init__(self, selectable, equivalents,
-                               include, exclude,
+                               include_fn=include_fn, exclude_fn=exclude_fn,
                                adapt_on_names=adapt_on_names,
                                anonymize_labels=anonymize_labels)
 
         if chain_to:
             self.chain(chain_to)
         self.columns = util.populate_column_dict(self._locate_col)
+        if self.include_fn or self.exclude_fn:
+            self.columns = self._IncludeExcludeMapping(self, self.columns)
         self.adapt_required = adapt_required
         self.allow_label_resolve = allow_label_resolve
+        self._wrap = None
+
+    class _IncludeExcludeMapping(object):
+        def __init__(self, parent, columns):
+            self.parent = parent
+            self.columns = columns
+
+        def __getitem__(self, key):
+            if (
+                self.parent.include_fn and not self.parent.include_fn(key)
+            ) or (
+                self.parent.exclude_fn and self.parent.exclude_fn(key)
+            ):
+                if self.parent._wrap:
+                    return self.parent._wrap.columns[key]
+                else:
+                    return key
+            return self.columns[key]
 
     def wrap(self, adapter):
         ac = self.__class__.__new__(self.__class__)
-        ac.__dict__ = self.__dict__.copy()
-        ac._locate_col = ac._wrap(ac._locate_col, adapter._locate_col)
-        ac.adapt_clause = ac._wrap(ac.adapt_clause, adapter.adapt_clause)
-        ac.adapt_list = ac._wrap(ac.adapt_list, adapter.adapt_list)
+        ac.__dict__.update(self.__dict__)
+        ac._wrap = adapter
         ac.columns = util.populate_column_dict(ac._locate_col)
+        if ac.include_fn or ac.exclude_fn:
+            ac.columns = self._IncludeExcludeMapping(ac, ac.columns)
+
         return ac
 
     def traverse(self, obj):
-        new_obj = ClauseAdapter.traverse(self, obj)
-        if new_obj is not obj:
-            self.columns[obj] = new_obj
-        return new_obj
+        return self.columns[obj]
 
     adapt_clause = traverse
     adapt_list = ClauseAdapter.copy_and_process
 
-    def _wrap(self, local, wrapped):
-        def locate(col):
-            col = local(col)
-            return wrapped(col)
-        return locate
-
     def _locate_col(self, col):
-        c = self._corresponding_column(col, True)
-        if c is None:
-            c = self.adapt_clause(col)
-
-        # adapt_required used by eager loading to indicate that
-        # we don't trust a result row column that is not translated.
-        # this is to prevent a column from being interpreted as that
-        # of the child row in a self-referential scenario, see
-        # inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency
+
+        c = ClauseAdapter.traverse(self, col)
+
+        if self._wrap:
+            c2 = self._wrap._locate_col(c)
+            if c2 is not None:
+                c = c2
+
         if self.adapt_required and c is col:
             return None
 
         c._allow_label_resolve = self.allow_label_resolve
-        return c
 
-    def adapted_row(self, row):
-        return AliasedRow(row, self.columns)
+        return c
 
     def __getstate__(self):
         d = self.__dict__.copy()
index a7184fe01f29a200de49f4d049f57e07a6cb8018..3f6813138247f413a8ce82949219c09ca5d71dca 100644 (file)
@@ -14,7 +14,7 @@ from sqlalchemy.testing.schema import Table, Column
 import sqlalchemy as sa
 from sqlalchemy.testing.assertions import (
     eq_, assert_raises, assert_raises_message, expect_warnings)
-from sqlalchemy.testing import fixtures, AssertsCompiledSQL
+from sqlalchemy.testing import fixtures, AssertsCompiledSQL, assert_warnings
 from test.orm import _fixtures
 from sqlalchemy.orm.util import join, with_parent
 
@@ -1236,21 +1236,23 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL):
     __dialect__ = 'default'
     run_setup_mappers = 'each'
 
-    def _fixture(self):
+    def _fixture(self, label=True):
         User, Address = self.classes("User", "Address")
         users, addresses = self.tables("users", "addresses")
+        stmt = select([func.max(addresses.c.email_address)]).\
+            where(addresses.c.user_id == users.c.id).\
+            correlate(users)
+        if label:
+            stmt = stmt.label("email_ad")
+
         mapper(User, users, properties={
-            "ead": column_property(
-                select([func.max(addresses.c.email_address)]).\
-                    where(addresses.c.user_id == users.c.id).\
-                    correlate(users).label("email_ad")
-            )
+            "ead": column_property(stmt)
         })
         mapper(Address, addresses)
 
-    def test_order_by_column_prop_label(self):
+    def test_order_by_column_prop_string(self):
         User, Address = self.classes("User", "Address")
-        self._fixture()
+        self._fixture(label=True)
 
         s = Session()
         q = s.query(User).order_by("email_ad")
@@ -1263,9 +1265,169 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL):
             "FROM users ORDER BY email_ad"
         )
 
-    def test_order_by_column_prop_attrname(self):
+    def test_order_by_column_prop_aliased_string(self):
+        User, Address = self.classes("User", "Address")
+        self._fixture(label=True)
+
+        s = Session()
+        ua = aliased(User)
+        q = s.query(ua).order_by("email_ad")
+
+        def go():
+            self.assert_compile(
+                q,
+                "SELECT (SELECT max(addresses.email_address) AS max_1 "
+                "FROM addresses WHERE addresses.user_id = users_1.id) "
+                "AS anon_1, users_1.id AS users_1_id, "
+                "users_1.name AS users_1_name FROM users AS users_1 "
+                "ORDER BY email_ad"
+            )
+        assert_warnings(
+            go,
+            ["Can't resolve label reference 'email_ad'"], regex=True)
+
+    def test_order_by_column_labeled_prop_attr_aliased_one(self):
+        User = self.classes.User
+        self._fixture(label=True)
+
+        ua = aliased(User)
+        s = Session()
+        q = s.query(ua).order_by(ua.ead)
+        self.assert_compile(
+            q,
+            "SELECT (SELECT max(addresses.email_address) AS max_1 "
+            "FROM addresses WHERE addresses.user_id = users_1.id) AS anon_1, "
+            "users_1.id AS users_1_id, users_1.name AS users_1_name "
+            "FROM users AS users_1 ORDER BY anon_1"
+        )
+
+    def test_order_by_column_labeled_prop_attr_aliased_two(self):
+        User = self.classes.User
+        self._fixture(label=True)
+
+        ua = aliased(User)
+        s = Session()
+        q = s.query(ua.ead).order_by(ua.ead)
+        self.assert_compile(
+            q,
+            "SELECT (SELECT max(addresses.email_address) AS max_1 "
+            "FROM addresses, "
+            "users AS users_1 WHERE addresses.user_id = users_1.id) "
+            "AS anon_1 ORDER BY anon_1"
+        )
+
+        # we're also testing that the state of "ua" is OK after the
+        # previous call, so the batching into one test is intentional
+        q = s.query(ua).order_by(ua.ead)
+        self.assert_compile(
+            q,
+            "SELECT (SELECT max(addresses.email_address) AS max_1 "
+            "FROM addresses WHERE addresses.user_id = users_1.id) AS anon_1, "
+            "users_1.id AS users_1_id, users_1.name AS users_1_name "
+            "FROM users AS users_1 ORDER BY anon_1"
+        )
+
+    def test_order_by_column_labeled_prop_attr_aliased_three(self):
+        User = self.classes.User
+        self._fixture(label=True)
+
+        ua = aliased(User)
+        s = Session()
+        q = s.query(User.ead, ua.ead).order_by(User.ead, ua.ead)
+        self.assert_compile(
+            q,
+            "SELECT (SELECT max(addresses.email_address) AS max_1 "
+            "FROM addresses, users WHERE addresses.user_id = users.id) "
+            "AS email_ad, (SELECT max(addresses.email_address) AS max_1 "
+            "FROM addresses, users AS users_1 WHERE addresses.user_id = "
+            "users_1.id) AS anon_1 ORDER BY email_ad, anon_1"
+        )
+
+        q = s.query(User, ua).order_by(User.ead, ua.ead)
+        self.assert_compile(
+            q,
+            "SELECT (SELECT max(addresses.email_address) AS max_1 "
+            "FROM addresses WHERE addresses.user_id = users.id) AS "
+            "email_ad, users.id AS users_id, users.name AS users_name, "
+            "(SELECT max(addresses.email_address) AS max_1 FROM addresses "
+            "WHERE addresses.user_id = users_1.id) AS anon_1, users_1.id "
+            "AS users_1_id, users_1.name AS users_1_name FROM users, "
+            "users AS users_1 ORDER BY email_ad, anon_1"
+        )
+
+    def test_order_by_column_unlabeled_prop_attr_aliased_one(self):
+        User = self.classes.User
+        self._fixture(label=False)
+
+        ua = aliased(User)
+        s = Session()
+        q = s.query(ua).order_by(ua.ead)
+        self.assert_compile(
+            q,
+            "SELECT (SELECT max(addresses.email_address) AS max_1 "
+            "FROM addresses WHERE addresses.user_id = users_1.id) AS anon_1, "
+            "users_1.id AS users_1_id, users_1.name AS users_1_name "
+            "FROM users AS users_1 ORDER BY anon_1"
+        )
+
+    def test_order_by_column_unlabeled_prop_attr_aliased_two(self):
+        User = self.classes.User
+        self._fixture(label=False)
+
+        ua = aliased(User)
+        s = Session()
+        q = s.query(ua.ead).order_by(ua.ead)
+        self.assert_compile(
+            q,
+            "SELECT (SELECT max(addresses.email_address) AS max_1 "
+            "FROM addresses, "
+            "users AS users_1 WHERE addresses.user_id = users_1.id) "
+            "AS anon_1 ORDER BY anon_1"
+        )
+
+        # we're also testing that the state of "ua" is OK after the
+        # previous call, so the batching into one test is intentional
+        q = s.query(ua).order_by(ua.ead)
+        self.assert_compile(
+            q,
+            "SELECT (SELECT max(addresses.email_address) AS max_1 "
+            "FROM addresses WHERE addresses.user_id = users_1.id) AS anon_1, "
+            "users_1.id AS users_1_id, users_1.name AS users_1_name "
+            "FROM users AS users_1 ORDER BY anon_1"
+        )
+
+    def test_order_by_column_unlabeled_prop_attr_aliased_three(self):
+        User = self.classes.User
+        self._fixture(label=False)
+
+        ua = aliased(User)
+        s = Session()
+        q = s.query(User.ead, ua.ead).order_by(User.ead, ua.ead)
+        self.assert_compile(
+            q,
+            "SELECT (SELECT max(addresses.email_address) AS max_1 "
+            "FROM addresses, users WHERE addresses.user_id = users.id) "
+            "AS anon_1, (SELECT max(addresses.email_address) AS max_1 "
+            "FROM addresses, users AS users_1 "
+            "WHERE addresses.user_id = users_1.id) AS anon_2 "
+            "ORDER BY anon_1, anon_2"
+        )
+
+        q = s.query(User, ua).order_by(User.ead, ua.ead)
+        self.assert_compile(
+            q,
+            "SELECT (SELECT max(addresses.email_address) AS max_1 "
+            "FROM addresses WHERE addresses.user_id = users.id) AS "
+            "anon_1, users.id AS users_id, users.name AS users_name, "
+            "(SELECT max(addresses.email_address) AS max_1 FROM addresses "
+            "WHERE addresses.user_id = users_1.id) AS anon_2, users_1.id "
+            "AS users_1_id, users_1.name AS users_1_name FROM users, "
+            "users AS users_1 ORDER BY anon_1, anon_2"
+        )
+
+    def test_order_by_column_prop_attr(self):
         User, Address = self.classes("User", "Address")
-        self._fixture()
+        self._fixture(label=True)
 
         s = Session()
         q = s.query(User).order_by(User.ead)
@@ -1281,9 +1443,9 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL):
             "FROM users ORDER BY email_ad"
         )
 
-    def test_order_by_column_prop_attrname_non_present(self):
+    def test_order_by_column_prop_attr_non_present(self):
         User, Address = self.classes("User", "Address")
-        self._fixture()
+        self._fixture(label=True)
 
         s = Session()
         q = s.query(User).options(defer(User.ead)).order_by(User.ead)
index 1140a1180854331a657472cebf101782daca94ca..013ba8082dd1afad1d8791e41a9a0816dfee1cf9 100644 (file)
@@ -10,7 +10,7 @@ from sqlalchemy.sql.visitors import ClauseVisitor, CloningVisitor, \
     cloned_traverse, ReplacingCloningVisitor
 from sqlalchemy import exc
 from sqlalchemy.sql import util as sql_util
-from sqlalchemy.testing import eq_, is_, assert_raises, assert_raises_message
+from sqlalchemy.testing import eq_, is_, is_not_, assert_raises, assert_raises_message
 
 A = B = t1 = t2 = t3 = table1 = table2 = table3 = table4 = None
 
@@ -696,6 +696,244 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
             "AS anon_1 WHERE table1.col1 = anon_1.col1)")
 
 
+class ColumnAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
+    __dialect__ = 'default'
+
+    @classmethod
+    def setup_class(cls):
+        global t1, t2
+        t1 = table("table1",
+                   column("col1"),
+                   column("col2"),
+                   column("col3"),
+                   column("col4")
+                   )
+        t2 = table("table2",
+                   column("col1"),
+                   column("col2"),
+                   column("col3"),
+                   )
+
+    def test_traverse_memoizes_w_columns(self):
+        t1a = t1.alias()
+        adapter = sql_util.ColumnAdapter(t1a, anonymize_labels=True)
+
+        expr = select([t1a.c.col1]).label('x')
+        expr_adapted = adapter.traverse(expr)
+        is_not_(expr, expr_adapted)
+        is_(
+            adapter.columns[expr],
+            expr_adapted
+        )
+
+    def test_traverse_memoizes_w_itself(self):
+        t1a = t1.alias()
+        adapter = sql_util.ColumnAdapter(t1a, anonymize_labels=True)
+
+        expr = select([t1a.c.col1]).label('x')
+        expr_adapted = adapter.traverse(expr)
+        is_not_(expr, expr_adapted)
+        is_(
+            adapter.traverse(expr),
+            expr_adapted
+        )
+
+    def test_columns_memoizes_w_itself(self):
+        t1a = t1.alias()
+        adapter = sql_util.ColumnAdapter(t1a, anonymize_labels=True)
+
+        expr = select([t1a.c.col1]).label('x')
+        expr_adapted = adapter.columns[expr]
+        is_not_(expr, expr_adapted)
+        is_(
+            adapter.columns[expr],
+            expr_adapted
+        )
+
+    def test_wrapping_fallthrough(self):
+        t1a = t1.alias(name="t1a")
+        t2a = t2.alias(name="t2a")
+        a1 = sql_util.ColumnAdapter(t1a)
+
+        s1 = select([t1a.c.col1, t2a.c.col1]).apply_labels().alias()
+        a2 = sql_util.ColumnAdapter(s1)
+        a3 = a2.wrap(a1)
+        a4 = a1.wrap(a2)
+        a5 = a1.chain(a2)
+
+        # t1.c.col1 -> s1.c.t1a_col1
+
+        # adapted by a2
+        is_(
+            a3.columns[t1.c.col1], s1.c.t1a_col1
+        )
+        is_(
+            a4.columns[t1.c.col1], s1.c.t1a_col1
+        )
+
+        # chaining can't fall through because a1 grabs it
+        # first
+        is_(
+            a5.columns[t1.c.col1], t1a.c.col1
+        )
+
+        # t2.c.col1 -> s1.c.t2a_col1
+
+        # adapted by a2
+        is_(
+            a3.columns[t2.c.col1], s1.c.t2a_col1
+        )
+        is_(
+            a4.columns[t2.c.col1], s1.c.t2a_col1
+        )
+        # chaining, t2 hits s1
+        is_(
+            a5.columns[t2.c.col1], s1.c.t2a_col1
+        )
+
+        # t1.c.col2 -> t1a.c.col2
+
+        # fallthrough to a1
+        is_(
+            a3.columns[t1.c.col2], t1a.c.col2
+        )
+        is_(
+            a4.columns[t1.c.col2], t1a.c.col2
+        )
+
+        # chaining hits a1
+        is_(
+            a5.columns[t1.c.col2], t1a.c.col2
+        )
+
+        # t2.c.col2 -> t2.c.col2
+
+        # fallthrough to no adaption
+        is_(
+            a3.columns[t2.c.col2], t2.c.col2
+        )
+        is_(
+            a4.columns[t2.c.col2], t2.c.col2
+        )
+
+    def test_wrapping_ordering(self):
+        """illustrate an example where order of wrappers matters.
+
+        This test illustrates both the ordering being significant
+        as well as a scenario where multiple translations are needed
+        (e.g. wrapping vs. chaining).
+
+        """
+
+        stmt = select([t1.c.col1, t2.c.col1]).apply_labels()
+
+        sa = stmt.alias()
+        stmt2 = select([t2, sa])
+
+        a1 = sql_util.ColumnAdapter(stmt)
+        a2 = sql_util.ColumnAdapter(stmt2)
+
+        a2_to_a1 = a2.wrap(a1)
+        a1_to_a2 = a1.wrap(a2)
+
+        # when stmt2 and stmt represent the same column
+        # in different contexts, order of wrapping matters
+
+        # t2.c.col1 via a2 is stmt2.c.col1; then ignored by a1
+        is_(
+            a2_to_a1.columns[t2.c.col1], stmt2.c.col1
+        )
+        # t2.c.col1 via a1 is stmt.c.table2_col1; a2 then
+        # sends this to stmt2.c.table2_col1
+        is_(
+            a1_to_a2.columns[t2.c.col1], stmt2.c.table2_col1
+        )
+
+        # for mutually exclusive columns, order doesn't matter
+        is_(
+            a2_to_a1.columns[t1.c.col1], stmt2.c.table1_col1
+        )
+        is_(
+            a1_to_a2.columns[t1.c.col1], stmt2.c.table1_col1
+        )
+        is_(
+            a2_to_a1.columns[t2.c.col2], stmt2.c.col2
+        )
+
+
+    def test_wrapping_multiple(self):
+        """illustrate that wrapping runs both adapters"""
+
+        t1a = t1.alias(name="t1a")
+        t2a = t2.alias(name="t2a")
+        a1 = sql_util.ColumnAdapter(t1a)
+        a2 = sql_util.ColumnAdapter(t2a)
+        a3 = a2.wrap(a1)
+
+        stmt = select([t1.c.col1, t2.c.col2])
+
+        self.assert_compile(
+            a3.traverse(stmt),
+            "SELECT t1a.col1, t2a.col2 FROM table1 AS t1a, table2 AS t2a"
+        )
+
+        # chaining does too because these adapters don't share any
+        # columns
+        a4 = a2.chain(a1)
+        self.assert_compile(
+            a4.traverse(stmt),
+            "SELECT t1a.col1, t2a.col2 FROM table1 AS t1a, table2 AS t2a"
+        )
+
+    def test_wrapping_inclusions(self):
+        """test wrapping and inclusion rules together,
+        taking into account multiple objects with equivalent hash identity."""
+
+        t1a = t1.alias(name="t1a")
+        t2a = t2.alias(name="t2a")
+        a1 = sql_util.ColumnAdapter(
+            t1a,
+            include_fn=lambda col: "a1" in col._annotations)
+
+        s1 = select([t1a, t2a]).apply_labels().alias()
+        a2 = sql_util.ColumnAdapter(
+            s1,
+            include_fn=lambda col: "a2" in col._annotations)
+        a3 = a2.wrap(a1)
+
+        c1a1 = t1.c.col1._annotate(dict(a1=True))
+        c1a2 = t1.c.col1._annotate(dict(a2=True))
+        c1aa = t1.c.col1._annotate(dict(a1=True, a2=True))
+
+        c2a1 = t2.c.col1._annotate(dict(a1=True))
+        c2a2 = t2.c.col1._annotate(dict(a2=True))
+        c2aa = t2.c.col1._annotate(dict(a1=True, a2=True))
+
+        is_(
+            a3.columns[c1a1], t1a.c.col1
+        )
+        is_(
+            a3.columns[c1a2], s1.c.t1a_col1
+        )
+        is_(
+            a3.columns[c1aa], s1.c.t1a_col1
+        )
+
+        # not covered by a1, accepted by a2
+        is_(
+            a3.columns[c2aa], s1.c.t2a_col1
+        )
+
+        # not covered by a1, accepted by a2
+        is_(
+            a3.columns[c2a2], s1.c.t2a_col1
+        )
+        # not covered by a1, rejected by a2
+        is_(
+            a3.columns[c2a1], c2a1
+        )
+
+
 class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = 'default'
 
@@ -1022,7 +1260,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
         assert str(e) == "a.id = a.xxx_id"
         b = a.alias()
 
-        e = sql_util.ClauseAdapter(b, include=set([a.c.id]),
+        e = sql_util.ClauseAdapter(b, include_fn=lambda x: x in set([a.c.id]),
                                    equivalents={a.c.id: set([a.c.id])}
                                    ).traverse(e)
 
@@ -1254,6 +1492,28 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
             "ORDER BY anon_1, anon_2"
         )
 
+    def test_label_anonymize_three(self):
+        t1a = t1.alias()
+        adapter = sql_util.ColumnAdapter(
+            t1a, anonymize_labels=True,
+            allow_label_resolve=False)
+
+        expr = select([t1.c.col2]).where(t1.c.col3 == 5).label(None)
+        l1 = expr
+        is_(l1._order_by_label_element, l1)
+        eq_(l1._allow_label_resolve, True)
+
+        expr_adapted = adapter.traverse(expr)
+        l2 = expr_adapted
+        is_(l2._order_by_label_element, l2)
+        eq_(l2._allow_label_resolve, False)
+
+        l3 = adapter.traverse(expr)
+        is_(l3._order_by_label_element, l3)
+        eq_(l3._allow_label_resolve, False)
+
+
+
 class SpliceJoinsTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = 'default'
 
index c5736b26f4a6a580e5f8719d52f8bad81cc30749..a3b2b0e930d39c9358f2f5642be73fae3d9c0462 100644 (file)
@@ -1724,6 +1724,13 @@ class AnnotationsTest(fixtures.TestBase):
         b5 = visitors.cloned_traverse(b3, {}, {'binary': visit_binary})
         assert str(b5) == ":bar = table1.col2"
 
+    def test_label_accessors(self):
+        t1 = table('t1', column('c1'))
+        l1 = t1.c.c1.label(None)
+        is_(l1._order_by_label_element, l1)
+        l1a = l1._annotate({"foo": "bar"})
+        is_(l1a._order_by_label_element, l1a)
+
     def test_annotate_aliased(self):
         t1 = table('t1', column('c1'))
         s = select([(t1.c.c1 + 3).label('bat')])