]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- [feature] "scalar" selects now have a WHERE method
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 15 Oct 2012 21:21:38 +0000 (17:21 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 15 Oct 2012 21:21:38 +0000 (17:21 -0400)
    to help with generative building.  Also slight adjustment
    regarding how SS "correlates" columns; the new methodology
    no longer applies meaning to the underlying
    Table column being selected.  This improves
    some fairly esoteric situations, and the logic
    that was there didn't seem to have any purpose.
  - [feature] Some support for auto-rendering of a
    relationship join condition based on the mapped
    attribute, with usage of core SQL constructs.
    E.g. select([SomeClass]).where(SomeClass.somerelationship)
    would render SELECT from "someclass" and use the
    primaryjoin of "somerelationship" as the WHERE
    clause.   This changes the previous meaning
    of "SomeClass.somerelationship" when used in a
    core SQL context; previously, it would "resolve"
    to the parent selectable, which wasn't generally
    useful.  Related to [ticket:2245].

CHANGES
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/expression.py
test/orm/test_eager_relations.py
test/orm/test_mapper.py
test/orm/test_query.py
test/sql/test_compiler.py
test/sql/test_selectable.py

diff --git a/CHANGES b/CHANGES
index a376bbc3ccdd591ebbfeb4f1ed0a1a846336d13b..c1f4ede430a02c02241bfa3f1d1a8971ac085525 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -285,6 +285,18 @@ underneath "0.7.xx".
     methods, where they will be unwrapped
     into selectables. [ticket:2245]
 
+  - [feature] Some support for auto-rendering of a
+    relationship join condition based on the mapped
+    attribute, with usage of core SQL constructs.
+    E.g. select([SomeClass]).where(SomeClass.somerelationship)
+    would render SELECT from "someclass" and use the
+    primaryjoin of "somerelationship" as the WHERE
+    clause.   This changes the previous meaning
+    of "SomeClass.somerelationship" when used in a
+    core SQL context; previously, it would "resolve"
+    to the parent selectable, which wasn't generally
+    useful.  Related to [ticket:2245].
+
   - [feature] The registry of classes
     in declarative_base() is now a
     WeakValueDictionary.  So subclasses of
@@ -654,6 +666,14 @@ underneath "0.7.xx".
     function against the new schema.CreateColumn
     construct.  [ticket:2463]
 
+  - [feature] "scalar" selects now have a WHERE method
+    to help with generative building.  Also slight adjustment
+    regarding how SS "correlates" columns; the new methodology
+    no longer applies meaning to the underlying
+    Table column being selected.  This improves
+    some fairly esoteric situations, and the logic
+    that was there didn't seem to have any purpose.
+
   - [bug] Fixes to the interpretation of the
     Column "default" parameter as a callable
     to not pass ExecutionContext into a keyword
index f8288f5fb369fe506bffd47a5ef3e8f7e2465944..048b4fad394ce4ca3f4414fff1f906705c08b692 100644 (file)
@@ -350,6 +350,8 @@ class RelationshipProperty(StrategizedProperty):
 
         """
 
+        _of_type = None
+
         def __init__(self, prop, mapper, of_type=None, adapter=None):
             """Construction of :class:`.RelationshipProperty.Comparator`
             is internal to the ORM's attribute mechanics.
@@ -376,13 +378,30 @@ class RelationshipProperty(StrategizedProperty):
         def parententity(self):
             return self.property.parent
 
-        def __clause_element__(self):
+        def _source_selectable(self):
             elem = self.property.parent._with_polymorphic_selectable
             if self.adapter:
                 return self.adapter(elem)
             else:
                 return elem
 
+        def __clause_element__(self):
+            adapt_from = self._source_selectable()
+            if self._of_type:
+                of_type = inspect(self._of_type).mapper
+            else:
+                of_type = None
+
+            pj, sj, source, dest, \
+            secondary, target_adapter = self.property._create_joins(
+                            source_selectable=adapt_from,
+                            source_polymorphic=True,
+                            of_type=of_type)
+            if sj is not None:
+                return pj & sj
+            else:
+                return pj
+
         def of_type(self, cls):
             """Produce a construct that represents a particular 'subtype' of
             attribute for the parent class.
@@ -477,7 +496,7 @@ class RelationshipProperty(StrategizedProperty):
                 to_selectable = None
 
             if self.adapter:
-                source_selectable = self.__clause_element__()
+                source_selectable = self._source_selectable()
             else:
                 source_selectable = None
 
index 750c3b298821445b1736873655326ce55f220275..2e38e0ce360242567560e323553fc1bde94c6f57 100644 (file)
@@ -796,7 +796,7 @@ class _ORMJoin(expression.Join):
                 prop = left_mapper.get_property(onclause)
             elif isinstance(onclause, attributes.QueryableAttribute):
                 if adapt_from is None:
-                    adapt_from = onclause.__clause_element__()
+                    adapt_from = onclause.comparator._source_selectable()
                 prop = onclause.property
             elif isinstance(onclause, MapperProperty):
                 prop = onclause
index 63b1a4037d0acb5c4047558d9b699d8ed9456667..5b6e4d82df261aca7b760424c10bb60912e7582c 100644 (file)
@@ -4839,7 +4839,7 @@ class SelectBase(Executable, FromClause):
         return [self]
 
 
-class ScalarSelect(Grouping):
+class ScalarSelect(Generative, Grouping):
     _from_objects = []
 
     def __init__(self, element):
@@ -4853,13 +4853,17 @@ class ScalarSelect(Grouping):
                 'column-level expression.')
     c = columns
 
+    @_generative
+    def where(self, crit):
+        """Apply a WHERE clause to the SELECT statement referred to
+        by this :class:`.ScalarSelect`.
+
+        """
+        self.element = self.element.where(crit)
+
     def self_group(self, **kwargs):
         return self
 
-    def _make_proxy(self, selectable, name=None, **kw):
-        return list(self.inner_columns)[0]._make_proxy(
-                            selectable, name=name)
-
 class CompoundSelect(SelectBase):
     """Forms the basis of ``UNION``, ``UNION ALL``, and other
         SELECT-based set operations."""
index 664ddb8e3047fd6c7b788834c672ee7e1ecc2699..39a7a7301daa21ca982e91fb8362d55840fd932b 100644 (file)
@@ -1442,11 +1442,12 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
         b_table, a_table = self.tables.b, self.tables.a
         self._fixture({})
         cp = select([func.sum(b_table.c.value)]).\
-                        where(b_table.c.a_id==a_table.c.id).\
+                        where(b_table.c.a_id == a_table.c.id).\
                         correlate(a_table).as_scalar()
-        # note its re-rendering the subquery in the
-        # outermost order by.  usually we want it to address
-        # the column within the subquery.  labelling fixes that.
+
+        # up until 0.8, this was ordering by a new subquery.
+        # the removal of a separate _make_proxy() from ScalarSelect
+        # fixed that.
         self.assert_compile(
             create_session().query(A).options(joinedload_all('bs')).
                             order_by(cp).
@@ -1458,8 +1459,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
             "b.a_id = a.id) AS anon_2 FROM a ORDER BY (SELECT "
             "sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) "
             "LIMIT :param_1) AS anon_1 LEFT OUTER JOIN b AS b_1 "
-            "ON anon_1.a_id = b_1.a_id ORDER BY "
-            "(SELECT anon_1.anon_2 FROM b WHERE b.a_id = anon_1.a_id)"
+            "ON anon_1.a_id = b_1.a_id ORDER BY anon_1.anon_2"
         )
 
     def test_standalone_subquery_labeled(self):
index b2e36b273616a8b8f0c5369f2b417dc383be90a6..f41843455c3f968ab16c09045b99f2da075d7ed9 100644 (file)
@@ -2190,28 +2190,46 @@ class ComparatorFactoryTest(_fixtures.FixtureTest, AssertsCompiledSQL):
 
         from sqlalchemy.orm.properties import PropertyLoader
 
+        # NOTE: this API changed in 0.8, previously __clause_element__()
+        # gave the parent selecatable, now it gives the
+        # primaryjoin/secondaryjoin
         class MyFactory(PropertyLoader.Comparator):
             __hash__ = None
             def __eq__(self, other):
-                return func.foobar(self.__clause_element__().c.user_id) == func.foobar(other.id)
+                return func.foobar(self._source_selectable().c.user_id) == \
+                    func.foobar(other.id)
 
         class MyFactory2(PropertyLoader.Comparator):
             __hash__ = None
             def __eq__(self, other):
-                return func.foobar(self.__clause_element__().c.id) == func.foobar(other.user_id)
+                return func.foobar(self._source_selectable().c.id) == \
+                    func.foobar(other.user_id)
 
         mapper(User, users)
         mapper(Address, addresses, properties={
-            'user':relationship(User, comparator_factory=MyFactory,
+            'user': relationship(User, comparator_factory=MyFactory,
                 backref=backref("addresses", comparator_factory=MyFactory2)
             )
             }
         )
-        self.assert_compile(Address.user == User(id=5), "foobar(addresses.user_id) = foobar(:foobar_1)", dialect=default.DefaultDialect())
-        self.assert_compile(User.addresses == Address(id=5, user_id=7), "foobar(users.id) = foobar(:foobar_1)", dialect=default.DefaultDialect())
 
-        self.assert_compile(aliased(Address).user == User(id=5), "foobar(addresses_1.user_id) = foobar(:foobar_1)", dialect=default.DefaultDialect())
-        self.assert_compile(aliased(User).addresses == Address(id=5, user_id=7), "foobar(users_1.id) = foobar(:foobar_1)", dialect=default.DefaultDialect())
+        # these are kind of nonsensical tests.
+        self.assert_compile(Address.user == User(id=5),
+                "foobar(addresses.user_id) = foobar(:foobar_1)",
+                dialect=default.DefaultDialect())
+        self.assert_compile(User.addresses == Address(id=5, user_id=7),
+                "foobar(users.id) = foobar(:foobar_1)",
+                dialect=default.DefaultDialect())
+
+        self.assert_compile(
+                aliased(Address).user == User(id=5),
+                "foobar(addresses_1.user_id) = foobar(:foobar_1)",
+                dialect=default.DefaultDialect())
+
+        self.assert_compile(
+                aliased(User).addresses == Address(id=5, user_id=7),
+                "foobar(users_1.id) = foobar(:foobar_1)",
+                dialect=default.DefaultDialect())
 
 
 class DeferredTest(_fixtures.FixtureTest):
index 56275a73513a114528f7c82b10e7e3dd02bcf75b..52f83ba32bc8be2bf116b55cac40f3cc7da0340d 100644 (file)
@@ -129,6 +129,26 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL):
             "SELECT * FROM users"
         )
 
+    def test_where_relationship(self):
+        User = self.classes.User
+
+        self.assert_compile(
+            select([User]).where(User.addresses),
+            "SELECT users.id, users.name FROM users, addresses "
+            "WHERE users.id = addresses.user_id"
+        )
+
+    def test_where_m2m_relationship(self):
+        Item = self.classes.Item
+
+        self.assert_compile(
+            select([Item]).where(Item.keywords),
+            "SELECT items.id, items.description FROM items, "
+            "item_keywords AS item_keywords_1, keywords "
+            "WHERE items.id = item_keywords_1.item_id "
+            "AND keywords.id = item_keywords_1.keyword_id"
+        )
+
     def test_inline_select_from_entity(self):
         User = self.classes.User
 
index b09ae1ab008374c82481c47ed75bf3df5e5ce9b6..53e62601069e7fdd25fa9dfec60de069b51db1ed 100644 (file)
@@ -620,6 +620,14 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
                             'mytable.description, (SELECT mytable.myid '
                             'FROM mytable) AS anon_1 FROM mytable')
 
+        s = select([table1.c.myid]).as_scalar()
+        s2 = s.where(table1.c.myid == 5)
+        self.assert_compile(
+            s2, "(SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_1)"
+        )
+        self.assert_compile(
+            s, "(SELECT mytable.myid FROM mytable)"
+        )
         # test that aliases use as_scalar() when used in an explicitly
         # scalar context
 
@@ -2018,7 +2026,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(
             expr,
             "x = :key",
-            {'x':12}
+            {'x': 12}
         )
 
     def test_bind_params_missing(self):
@@ -2114,7 +2122,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(table1.c.myid.in_([literal('a'), table1.c.myid]),
         "mytable.myid IN (:param_1, mytable.myid)")
 
-        self.assert_compile(table1.c.myid.in_([literal('a'), table1.c.myid +'a']),
+        self.assert_compile(table1.c.myid.in_([literal('a'), table1.c.myid + 'a']),
         "mytable.myid IN (:param_1, mytable.myid + :myid_1)")
 
         self.assert_compile(table1.c.myid.in_([literal(1), 'a' + table1.c.myid]),
@@ -2144,11 +2152,13 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
 
         self.assert_compile(
             select([table1.c.myid.in_(select([table2.c.otherid]))]),
-            "SELECT mytable.myid IN (SELECT myothertable.otherid FROM myothertable) AS anon_1 FROM mytable"
+            "SELECT mytable.myid IN (SELECT myothertable.otherid "
+                "FROM myothertable) AS anon_1 FROM mytable"
         )
         self.assert_compile(
             select([table1.c.myid.in_(select([table2.c.otherid]).as_scalar())]),
-            "SELECT mytable.myid IN (SELECT myothertable.otherid FROM myothertable) AS anon_1 FROM mytable"
+            "SELECT mytable.myid IN (SELECT myothertable.otherid "
+                "FROM myothertable) AS anon_1 FROM mytable"
         )
 
         self.assert_compile(table1.c.myid.in_(
@@ -2160,17 +2170,24 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
         "SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_1 "\
         "UNION SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_2)")
 
-        # test that putting a select in an IN clause does not blow away its ORDER BY clause
+        # test that putting a select in an IN clause does not
+        # blow away its ORDER BY clause
         self.assert_compile(
             select([table1, table2],
                 table2.c.otherid.in_(
-                    select([table2.c.otherid], order_by=[table2.c.othername], limit=10, correlate=False)
+                    select([table2.c.otherid], order_by=[table2.c.othername],
+                                        limit=10, correlate=False)
                 ),
-                from_obj=[table1.join(table2, table1.c.myid==table2.c.otherid)], order_by=[table1.c.myid]
+                from_obj=[table1.join(table2,
+                            table1.c.myid == table2.c.otherid)],
+                order_by=[table1.c.myid]
             ),
-            "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername FROM mytable "\
-            "JOIN myothertable ON mytable.myid = myothertable.otherid WHERE myothertable.otherid IN (SELECT myothertable.otherid "\
-            "FROM myothertable ORDER BY myothertable.othername LIMIT :param_1) ORDER BY mytable.myid",
+            "SELECT mytable.myid, mytable.name, mytable.description, "
+            "myothertable.otherid, myothertable.othername FROM mytable "\
+            "JOIN myothertable ON mytable.myid = myothertable.otherid "
+            "WHERE myothertable.otherid IN (SELECT myothertable.otherid "\
+            "FROM myothertable ORDER BY myothertable.othername "
+            "LIMIT :param_1) ORDER BY mytable.myid",
             {'param_1':10}
         )
 
index 374147a1b70b33423cda5f131e9fd4b65b321426..bbf7eeab12b3cf02f90a37ba1ee0795c190b9053 100644 (file)
@@ -62,11 +62,11 @@ class SelectableTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiled
 
         eq_(
             s1.c.foo.proxy_set,
-            set([s1.c.foo, scalar_select, scalar_select.element, table1.c.col1])
+            set([s1.c.foo, scalar_select, scalar_select.element])
         )
         eq_(
             s2.c.foo.proxy_set,
-            set([s2.c.foo, scalar_select, scalar_select.element, table1.c.col1])
+            set([s2.c.foo, scalar_select, scalar_select.element])
         )
 
         assert s1.corresponding_column(scalar_select) is s1.c.foo