]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
wip for #3148
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 6 Sep 2014 21:56:53 +0000 (17:56 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 6 Sep 2014 21:56:53 +0000 (17:56 -0400)
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/sql/util.py
test/orm/test_eager_relations.py
test/orm/test_query.py
test/sql/test_text.py

index 2159d9135075e32d4ad4e774df7b19337c4913ee..84dd6b0456993892be3b38a5e879b6e90f6e308b 100644 (file)
@@ -1242,7 +1242,7 @@ class JoinedLoader(AbstractRelationshipLoader):
         clauses = orm_util.ORMAdapter(
             to_adapt,
             equivalents=self.mapper._equivalent_columns,
-            adapt_required=True)
+            adapt_required=True, allow_label_resolve=False)
         assert clauses.aliased_class is not None
 
         if self.parent_property.direction != interfaces.MANYTOONE:
index ea7bfc2948bd7efe297ad5941e25dc84711ed0a6..3072d6ffb6bab5aa6abad74771b669a2062e9463 100644 (file)
@@ -278,7 +278,7 @@ class ORMAdapter(sql_util.ColumnAdapter):
     """
 
     def __init__(self, entity, equivalents=None, adapt_required=False,
-                 chain_to=None):
+                 chain_to=None, allow_label_resolve=True):
         info = inspection.inspect(entity)
 
         self.mapper = info.mapper
@@ -288,9 +288,10 @@ class ORMAdapter(sql_util.ColumnAdapter):
             self.aliased_class = entity
         else:
             self.aliased_class = None
-        sql_util.ColumnAdapter.__init__(self, selectable,
-                                        equivalents, chain_to,
-                                        adapt_required=adapt_required)
+        sql_util.ColumnAdapter.__init__(
+            self, selectable, equivalents, chain_to,
+            adapt_required=adapt_required,
+            allow_label_resolve=allow_label_resolve)
 
     def replace(self, elem):
         entity = elem._annotations.get('parentmapper', None)
index af0fff826bb5a3970ab5328468a34ce0c2737429..4349c97f423239e28080531b998fc9757e9f4480 100644 (file)
@@ -512,7 +512,7 @@ class SQLCompiler(Compiled):
 
         selectable = self.stack[-1]['selectable']
         try:
-            col = selectable._inner_column_dict[element.text]
+            col = selectable._label_resolve_dict[element.text]
         except KeyError:
             # treat it like text()
             util.warn_limited(
@@ -701,6 +701,10 @@ class SQLCompiler(Compiled):
         # here; we can only add a label in the ORDER BY for an individual
         # label expression in the columns clause.
 
+        # TODO: we should see if we can bring _resolve_label
+        # into this
+
+
         raw_col = set(l._order_by_label_element.name
                       for l in order_by_select._raw_columns
                       if l._order_by_label_element is not None)
index 870e96437b3851f52db43aef452edb9dd52e5501..c8504f21f2c8e60fa9ba963f6812c4dabad06e1f 100644 (file)
@@ -675,6 +675,19 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
 
     """
 
+    _resolve_label = None
+    """The name that should be used to identify this ColumnElement in a
+    select() object when "label resolution" logic is used; this refers
+    to using a string name in an expression like order_by() or group_by()
+    that wishes to target a labeled expression in the columns clause.
+
+    The name is distinct from that of .name or ._label to account for the case
+    where anonymizing logic may be used to change the name that's actually
+    rendered at compile time; this attribute should hold onto the original
+    name that was user-assigned when producing a .label() construct.
+
+    """
+
     _alt_names = ()
 
     def self_group(self, against=None):
@@ -691,6 +704,8 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
         else:
             return super(ColumnElement, self)._negate()
 
+    _allow_label_resolve = True
+
     @util.memoized_property
     def type(self):
         return type_api.NULLTYPE
@@ -1231,7 +1246,7 @@ class TextClause(Executable, ClauseElement):
 
     # help in those cases where text() is
     # interpreted in a column expression situation
-    key = _label = None
+    key = _label = _resolve_label = None
 
     def __init__(
             self,
@@ -2869,8 +2884,13 @@ class Label(ColumnElement):
         :param obj: a :class:`.ColumnElement`.
 
         """
+
+        if isinstance(element, Label):
+            self._resolve_label = element._label
+
         while isinstance(element, Label):
             element = element.element
+
         if name:
             self.name = name
         else:
@@ -2885,6 +2905,10 @@ class Label(ColumnElement):
     def __reduce__(self):
         return self.__class__, (self.name, self._element, self._type)
 
+    @util.memoized_property
+    def _allow_label_resolve(self):
+        return self.element._allow_label_resolve
+
     @util.memoized_property
     def _order_by_label_element(self):
         return self
index b802a694400120dfc8a67182805ac98acdf6d409..57b16f45f36c5da5b17d92e783a81ec09cccae0a 100644 (file)
@@ -1814,7 +1814,7 @@ class GenerativeSelect(SelectBase):
                 *clauses, _literal_as_text=_literal_as_label_reference)
 
     @property
-    def _inner_column_dict(self):
+    def _label_resolve_dict(self):
         raise NotImplementedError()
 
     def _copy_internals(self, clone=_clone, **kw):
@@ -1884,7 +1884,7 @@ class CompoundSelect(GenerativeSelect):
         GenerativeSelect.__init__(self, **kwargs)
 
     @property
-    def _inner_column_dict(self):
+    def _label_resolve_dict(self):
         return dict(
             (c.key, c) for c in self.c
         )
@@ -2498,11 +2498,14 @@ class Select(HasPrefixes, GenerativeSelect):
         return _select_iterables(self._raw_columns)
 
     @_memoized_property
-    def _inner_column_dict(self):
+    def _label_resolve_dict(self):
         d = dict(
-            (c._label or c.key, c)
-            for c in _select_iterables(self._raw_columns))
-        d.update((c.key, c) for c in _select_iterables(self.froms))
+            (c._resolve_label or c._label or c.key, c)
+            for c in _select_iterables(self._raw_columns)
+            if c._allow_label_resolve)
+        d.update(
+            (c.key, c) for c in
+            _select_iterables(self.froms) if c._allow_label_resolve)
 
         return d
 
index 8bbae8b93b1f161f6b3f338b986f66ada937207f..47ab61fdd6bbbe522342651f2c9402532a88c522 100644 (file)
@@ -548,13 +548,15 @@ class ColumnAdapter(ClauseAdapter):
 
     def __init__(self, selectable, equivalents=None,
                  chain_to=None, include=None,
-                 exclude=None, adapt_required=False):
+                 exclude=None, adapt_required=False,
+                 allow_label_resolve=True):
         ClauseAdapter.__init__(self, selectable, equivalents,
                                include, exclude)
         if chain_to:
             self.chain(chain_to)
         self.columns = util.populate_column_dict(self._locate_col)
         self.adapt_required = adapt_required
+        self.allow_label_resolve = allow_label_resolve
 
     def wrap(self, adapter):
         ac = self.__class__.__new__(self.__class__)
@@ -580,6 +582,7 @@ class ColumnAdapter(ClauseAdapter):
             c = self.adapt_clause(col)
 
             # anonymize labels in case they have a hardcoded name
+            # see test_eager_relations.py -> SubqueryTest.test_label_anonymizing
             if isinstance(c, Label):
                 c = c.label(None)
 
@@ -591,6 +594,7 @@ class ColumnAdapter(ClauseAdapter):
         if self.adapt_required and c is col:
             return None
 
+        c._allow_label_resolve = self.allow_label_resolve
         return c
 
     def adapted_row(self, row):
index b0c203bf189892dbcfe4a842c4af14948c099833..214b592b5f38a40bf6d865b355ee8155668d632c 100644 (file)
@@ -14,7 +14,7 @@ from sqlalchemy.orm import mapper, relationship, create_session, \
 from sqlalchemy.sql import operators
 from sqlalchemy.testing import assert_raises, assert_raises_message
 from sqlalchemy.testing.assertsql import CompiledSQL
-from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import fixtures, expect_warnings
 from test.orm import _fixtures
 from sqlalchemy.util import OrderedDict as odict
 import datetime
@@ -210,6 +210,55 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
             User(id=10, addresses=[])
             ], sess.query(User).order_by(User.id).all())
 
+    def test_no_ad_hoc_orderby(self):
+        """part of #2992; make sure string label references can't
+        access an eager loader, else an eager load can corrupt the query.
+
+        """
+        Address, addresses, users, User = (self.classes.Address,
+                                           self.tables.addresses,
+                                           self.tables.users,
+                                           self.classes.User)
+
+        mapper(Address, addresses)
+        mapper(User, users, properties=dict(
+            addresses=relationship(
+                Address),
+        ))
+
+        sess = create_session()
+        q = sess.query(User).\
+            join("addresses").\
+            options(joinedload("addresses")).\
+            order_by("email_address")
+
+        self.assert_compile(
+            q,
+            "SELECT users.id AS users_id, users.name AS users_name, "
+            "addresses_1.id AS addresses_1_id, addresses_1.user_id AS "
+            "addresses_1_user_id, addresses_1.email_address AS "
+            "addresses_1_email_address FROM users JOIN addresses "
+            "ON users.id = addresses.user_id LEFT OUTER JOIN addresses "
+            "AS addresses_1 ON users.id = addresses_1.user_id "
+            "ORDER BY addresses.email_address"
+        )
+
+        q = sess.query(User).options(joinedload("addresses")).\
+            order_by("email_address")
+
+        with expect_warnings("Can't resolve label reference 'email_address'"):
+            self.assert_compile(
+                q,
+                "SELECT users.id AS users_id, users.name AS users_name, "
+                "addresses_1.id AS addresses_1_id, addresses_1.user_id AS "
+                "addresses_1_user_id, addresses_1.email_address AS "
+                "addresses_1_email_address FROM users LEFT OUTER JOIN "
+                "addresses AS addresses_1 ON users.id = addresses_1.user_id "
+                "ORDER BY email_address"
+            )
+
+
+
     def test_deferred_fk_col(self):
         users, Dingaling, User, dingalings, Address, addresses = (
             self.tables.users,
index f0470e17297be31434df0a75d87a6c8a6813338e..a7184fe01f29a200de49f4d049f57e07a6cb8018 100644 (file)
@@ -8,7 +8,7 @@ from sqlalchemy.engine import default
 from sqlalchemy.orm import (
     attributes, mapper, relationship, create_session, synonym, Session,
     aliased, column_property, joinedload_all, joinedload, Query, Bundle,
-    subqueryload, backref, lazyload)
+    subqueryload, backref, lazyload, defer)
 from sqlalchemy.testing.assertsql import CompiledSQL
 from sqlalchemy.testing.schema import Table, Column
 import sqlalchemy as sa
@@ -1232,6 +1232,70 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL):
                     Address(email_address='jack@bean.com', user_id=7, id=1))])
 
 
+class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL):
+    __dialect__ = 'default'
+    run_setup_mappers = 'each'
+
+    def _fixture(self):
+        User, Address = self.classes("User", "Address")
+        users, addresses = self.tables("users", "addresses")
+        mapper(User, users, properties={
+            "ead": column_property(
+                select([func.max(addresses.c.email_address)]).\
+                    where(addresses.c.user_id == users.c.id).\
+                    correlate(users).label("email_ad")
+            )
+        })
+        mapper(Address, addresses)
+
+    def test_order_by_column_prop_label(self):
+        User, Address = self.classes("User", "Address")
+        self._fixture()
+
+        s = Session()
+        q = s.query(User).order_by("email_ad")
+        self.assert_compile(
+            q,
+            "SELECT (SELECT max(addresses.email_address) AS max_1 "
+            "FROM addresses "
+            "WHERE addresses.user_id = users.id) AS email_ad, "
+            "users.id AS users_id, users.name AS users_name "
+            "FROM users ORDER BY email_ad"
+        )
+
+    def test_order_by_column_prop_attrname(self):
+        User, Address = self.classes("User", "Address")
+        self._fixture()
+
+        s = Session()
+        q = s.query(User).order_by(User.ead)
+        # this one is a bit of a surprise; this is compiler
+        # label-order-by logic kicking in, but won't work in more
+        # complex cases.
+        self.assert_compile(
+            q,
+            "SELECT (SELECT max(addresses.email_address) AS max_1 "
+            "FROM addresses "
+            "WHERE addresses.user_id = users.id) AS email_ad, "
+            "users.id AS users_id, users.name AS users_name "
+            "FROM users ORDER BY email_ad"
+        )
+
+    def test_order_by_column_prop_attrname_non_present(self):
+        User, Address = self.classes("User", "Address")
+        self._fixture()
+
+        s = Session()
+        q = s.query(User).options(defer(User.ead)).order_by(User.ead)
+        self.assert_compile(
+            q,
+            "SELECT users.id AS users_id, users.name AS users_name "
+            "FROM users ORDER BY (SELECT max(addresses.email_address) AS max_1 "
+            "FROM addresses "
+            "WHERE addresses.user_id = users.id)"
+        )
+
+
 # more slice tests are available in test/orm/generative.py
 class SliceTest(QueryTest):
     def test_first(self):
index e84a2907c7ff5591bae97fd35dd739b1f1b84658..94627ae073f5b97bf545cc77be6f778785077db6 100644 (file)
@@ -6,7 +6,7 @@ from sqlalchemy import text, select, Integer, String, Float, \
     bindparam, and_, func, literal_column, exc, MetaData, Table, Column,\
     asc, func, desc, union
 from sqlalchemy.types import NullType
-from sqlalchemy.sql import table, column
+from sqlalchemy.sql import table, column, util as sql_util
 from sqlalchemy import util
 
 table1 = table('mytable',
@@ -679,3 +679,29 @@ class OrderByLabelResolutionTest(fixtures.TestBase, AssertsCompiledSQL):
             desc("somelabel"),
             "somelabel DESC"
         )
+
+    def test_anonymized_via_columnadapter(self):
+        """test issue #3148
+
+        Testing the anonymization applied from the ColumnAdapter.columns
+        collection, typically as used in eager loading.
+
+        """
+        exprs = [
+            table1.c.myid,
+            table1.c.name.label('t1name'),
+            func.foo("hoho").label('x')]
+
+        ta = table1.alias()
+        adapter = sql_util.ColumnAdapter(ta)
+
+        s1 = select([adapter.columns[expr] for expr in exprs]).\
+            apply_labels().order_by("myid", "t1name", "x")
+
+        # our "t1name" and "x" labels get modified
+        self.assert_compile(
+            s1,
+            "SELECT mytable_1.myid AS mytable_1_myid, "
+            "mytable_1.name AS name_1, foo(:foo_2) AS foo_1 "
+            "FROM mytable AS mytable_1 ORDER BY mytable_1.myid, name_1, foo_1"
+        )