]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Improved query.join() such that the "left" side
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 12 Oct 2011 19:15:28 +0000 (15:15 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 12 Oct 2011 19:15:28 +0000 (15:15 -0400)
can more flexibly be a non-ORM selectable,
such as a subquery.   A selectable placed
in select_from() will now be used as the left
side, favored over implicit usage
of a mapped entity.
If the join still fails based on lack of
foreign keys, the error message includes
this detail.  Thanks to brianrhude
on IRC for the test case.  [ticket:2298]

CHANGES
lib/sqlalchemy/orm/query.py
test/orm/test_froms.py
test/orm/test_joins.py

diff --git a/CHANGES b/CHANGES
index 7d3dd04ec13e31ac63d87904932ff6188e045dd0..d072eba6ec3cb7bf5bc5d1f009c0d5cacdb54c21 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -15,6 +15,17 @@ CHANGES
      fixes [ticket:2279].  Also in 0.6.9.
 
 - orm
+   - Improved query.join() such that the "left" side
+     can more flexibly be a non-ORM selectable, 
+     such as a subquery.   A selectable placed
+     in select_from() will now be used as the left
+     side, favored over implicit usage
+     of a mapped entity.
+     If the join still fails based on lack of
+     foreign keys, the error message includes 
+     this detail.  Thanks to brianrhude
+     on IRC for the test case.  [ticket:2298]
+
    - Added after_soft_rollback() Session event.  This
      event fires unconditionally whenever rollback()
      is called, regardless of if an actual DBAPI
index 6ef6c3b573773417b8da28f77f62095a420f4692..88532d90961fd799b375ed0c46c6e5dd5a4c14de 100644 (file)
@@ -1626,7 +1626,10 @@ class Query(object):
         """append a JOIN to the query's from clause."""
 
         if left is None:
-            left = self._joinpoint_zero()
+            if self._from_obj:
+                left = self._from_obj[0]
+            elif self._entities:
+                left = self._entities[0].entity_zero_or_selectable
 
         if left is right and \
                 not create_aliases:
@@ -1742,10 +1745,15 @@ class Query(object):
                     sql_util.clause_is_present(left_selectable, clause):
                     join_to_left = False
 
-                clause = orm_join(clause, 
+                try:
+                    clause = orm_join(clause, 
                                     right, 
                                     onclause, isouter=outerjoin, 
                                     join_to_left=join_to_left)
+                except sa_exc.ArgumentError, ae:
+                    raise sa_exc.InvalidRequestError(
+                            "Could not find a FROM clause to join from.  "
+                            "Tried joining to %s, but got: %s" % (right, ae))
 
                 self._from_obj = \
                         self._from_obj[:replace_clause_index] + \
@@ -1760,6 +1768,8 @@ class Query(object):
                     break
             else:
                 clause = left
+        elif left_selectable is not None:
+            clause = left_selectable
         else:
             clause = None
 
@@ -1767,8 +1777,13 @@ class Query(object):
             raise sa_exc.InvalidRequestError(
                     "Could not find a FROM clause to join from")
 
-        clause = orm_join(clause, right, onclause, 
+        try:
+            clause = orm_join(clause, right, onclause, 
                                 isouter=outerjoin, join_to_left=join_to_left)
+        except sa_exc.ArgumentError, ae:
+            raise sa_exc.InvalidRequestError(
+                    "Could not find a FROM clause to join from.  "
+                    "Tried joining to %s, but got: %s" % (right, ae))
 
         self._from_obj = self._from_obj + (clause,)
 
@@ -2890,6 +2905,10 @@ class _MapperEntity(_QueryEntity):
     def type(self):
         return self.mapper.class_
 
+    @property
+    def entity_zero_or_selectable(self):
+        return self.entity_zero
+
     def corresponds_to(self, entity):
         if _is_aliased_class(entity) or self.is_aliased_class:
             return entity is self.path_entity
@@ -3055,7 +3074,7 @@ class _ColumnEntity(_QueryEntity):
         # of FROMs for the overall expression - this helps
         # subqueries which were built from ORM constructs from
         # leaking out their entities into the main select construct
-        actual_froms = set(column._from_objects)
+        self.actual_froms = actual_froms = set(column._from_objects)
 
         self.entities = util.OrderedSet(
             elem._annotations['parententity']
@@ -3069,6 +3088,15 @@ class _ColumnEntity(_QueryEntity):
         else:
             self.entity_zero = None
 
+    @property
+    def entity_zero_or_selectable(self):
+        if self.entity_zero:
+            return self.entity_zero
+        elif self.actual_froms:
+            return list(self.actual_froms)[0]
+        else:
+            return None
+
     @property
     def type(self):
         return self.column.type
index e2fb55129d4f7dd87954a5d530ebcdfa62bc2382..c8fc0af79bcfcf5b8b8244c17ab698858dc037bf 100644 (file)
@@ -1610,6 +1610,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
 
 class SelectFromTest(QueryTest, AssertsCompiledSQL):
     run_setup_mappers = None
+    __dialect__ = 'default'
 
     def test_replace_with_select(self):
         users, Address, addresses, User = (self.tables.users,
@@ -1676,7 +1677,6 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL):
             "SELECT users.id AS users_id, users.name AS users_name FROM "
             "users JOIN (SELECT users.id AS id, users.name AS name FROM "
             "users WHERE users.id IN (:id_1, :id_2)) AS anon_1 ON users.id > anon_1.id",
-            use_default_dialect=True
         )
 
         self.assert_compile(
@@ -1684,18 +1684,22 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL):
             "SELECT users_1.id AS users_1_id, users_1.name AS users_1_name FROM "
             "users AS users_1, (SELECT users.id AS id, users.name AS name FROM "
             "users WHERE users.id IN (:id_1, :id_2)) AS anon_1 WHERE users_1.id > anon_1.id",
-            use_default_dialect=True
         )
 
-        # these two are essentially saying, "join ualias to ualias", so an 
-        # error is raised.  join() deals with entities, not what's in
-        # select_from().
-        assert_raises(sa_exc.InvalidRequestError,
-            sess.query(ualias).select_from(sel).join, ualias, ualias.id>sel.c.id
+        self.assert_compile(
+            sess.query(ualias).select_from(sel).join(ualias, ualias.id>sel.c.id),
+            "SELECT users_1.id AS users_1_id, users_1.name AS users_1_name "
+            "FROM (SELECT users.id AS id, users.name AS name "
+            "FROM users WHERE users.id IN (:id_1, :id_2)) AS anon_1 "
+            "JOIN users AS users_1 ON users_1.id > anon_1.id"
         )
 
-        assert_raises(sa_exc.InvalidRequestError,
-            sess.query(ualias).select_from(sel).join, ualias, ualias.id>User.id
+        self.assert_compile(
+            sess.query(ualias).select_from(sel).join(ualias, ualias.id>User.id),
+            "SELECT users_1.id AS users_1_id, users_1.name AS users_1_name "
+            "FROM (SELECT users.id AS id, users.name AS name FROM "
+            "users WHERE users.id IN (:id_1, :id_2)) AS anon_1 "
+            "JOIN users AS users_1 ON anon_1.id < users_1.id"
         )
 
         salias = aliased(User, sel)
@@ -1704,7 +1708,6 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL):
             "SELECT anon_1.id AS anon_1_id, anon_1.name AS anon_1_name FROM "
             "(SELECT users.id AS id, users.name AS name FROM users WHERE users.id "
             "IN (:id_1, :id_2)) AS anon_1 JOIN users AS users_1 ON users_1.id > anon_1.id",
-            use_default_dialect=True
         )
 
 
index ccd42c10bd536793ad3bf18c69540d9f0ae51737..db7c78cdd697bde82f8901407a52e54f520a6d23 100644 (file)
@@ -881,7 +881,11 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
 
         sess = create_session()
 
-        assert_raises_message(sa.exc.InvalidRequestError, "Could not find a FROM clause to join from", sess.query(users).join, addresses)
+        self.assert_compile(
+            sess.query(users).join(addresses),
+            "SELECT users.id AS users_id, users.name AS users_name "
+            "FROM users JOIN addresses ON users.id = addresses.user_id"
+        )
 
 
     def test_orderby_arg_bug(self):
@@ -1261,13 +1265,19 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
 
         assert_raises_message(
             sa_exc.InvalidRequestError,
-            "Could not find a FROM",
+            "Could not find a FROM clause to join from.  Tried joining "
+            "to .*?, but got: "
+            "Can't find any foreign key relationships "
+            "between 'users' and 'users'.",
             sess.query(users.c.id).join, User
         )
 
         assert_raises_message(
             sa_exc.InvalidRequestError,
-            "Could not find a FROM",
+            "Could not find a FROM clause to join from.  Tried joining "
+            "to .*?, but got: "
+            "Can't find any foreign key relationships "
+            "between 'users' and 'users'.",
             sess.query(users.c.id).select_from(users).join, User
         )
 
@@ -1322,6 +1332,151 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             use_default_dialect=True
         )
 
+class JoinFromSelectableTest(fixtures.MappedTest, AssertsCompiledSQL):
+    __dialect__ = 'default'
+    run_setup_mappers = 'once'
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('table1', metadata, 
+            Column('id', Integer, primary_key=True)
+        )
+        Table('table2', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('t1_id', Integer)
+        )
+
+    @classmethod
+    def setup_classes(cls):
+        table1, table2 = cls.tables.table1, cls.tables.table2
+        class T1(cls.Comparable):
+            pass
+
+        class T2(cls.Comparable):
+            pass
+
+        mapper(T1, table1)
+        mapper(T2, table2)
+
+    def test_select_mapped_to_mapped_explicit_left(self):
+        T1, T2 = self.classes.T1, self.classes.T2
+
+        sess = Session()
+        subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\
+                    group_by(T2.t1_id).subquery()
+
+        self.assert_compile(
+            sess.query(subq.c.count, T1.id).select_from(subq).join(T1, subq.c.t1_id==T1.id),
+            "SELECT anon_1.count AS anon_1_count, table1.id AS table1_id "
+            "FROM (SELECT table2.t1_id AS t1_id, "
+            "count(table2.id) AS count FROM table2 "
+            "GROUP BY table2.t1_id) AS anon_1 JOIN table1 ON anon_1.t1_id = table1.id"
+        )
+
+    def test_select_mapped_to_mapped_implicit_left(self):
+        T1, T2 = self.classes.T1, self.classes.T2
+
+        sess = Session()
+        subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\
+                    group_by(T2.t1_id).subquery()
+
+        self.assert_compile(
+            sess.query(subq.c.count, T1.id).join(T1, subq.c.t1_id==T1.id),
+            "SELECT anon_1.count AS anon_1_count, table1.id AS table1_id "
+            "FROM (SELECT table2.t1_id AS t1_id, "
+            "count(table2.id) AS count FROM table2 "
+            "GROUP BY table2.t1_id) AS anon_1 JOIN table1 ON anon_1.t1_id = table1.id"
+        )
+
+    def test_select_mapped_to_select_explicit_left(self):
+        T1, T2 = self.classes.T1, self.classes.T2
+
+        sess = Session()
+        subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\
+                    group_by(T2.t1_id).subquery()
+
+        self.assert_compile(
+            sess.query(subq.c.count, T1.id).select_from(T1).join(subq, subq.c.t1_id==T1.id),
+            "SELECT anon_1.count AS anon_1_count, table1.id AS table1_id "
+            "FROM table1 JOIN (SELECT table2.t1_id AS t1_id, "
+            "count(table2.id) AS count FROM table2 GROUP BY table2.t1_id) "
+            "AS anon_1 ON anon_1.t1_id = table1.id"
+        )
+
+    def test_select_mapped_to_select_implicit_left(self):
+        T1, T2 = self.classes.T1, self.classes.T2
+
+        sess = Session()
+        subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\
+                    group_by(T2.t1_id).subquery()
+
+        assert_raises_message(
+            sa_exc.InvalidRequestError,
+            r"Can't construct a join from ",
+            sess.query(subq.c.count, T1.id).join, subq, subq.c.t1_id==T1.id,
+        )
+
+    def test_mapped_select_to_mapped_implicit_left(self):
+        T1, T2 = self.classes.T1, self.classes.T2
+
+        sess = Session()
+        subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\
+                    group_by(T2.t1_id).subquery()
+
+        # this query is wrong, but verifying behavior stays the same
+        # (or improves, like an error message)
+        self.assert_compile(
+            sess.query(T1.id, subq.c.count).join(T1, subq.c.t1_id==T1.id),
+            "SELECT table1.id AS table1_id, anon_1.count AS anon_1_count FROM "
+            "(SELECT table2.t1_id AS t1_id, count(table2.id) AS count FROM "
+            "table2 GROUP BY table2.t1_id) AS anon_1, table1 JOIN table1 "
+            "ON anon_1.t1_id = table1.id"
+        )
+
+    def test_mapped_select_to_mapped_explicit_left(self):
+        T1, T2 = self.classes.T1, self.classes.T2
+
+        sess = Session()
+        subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\
+                    group_by(T2.t1_id).subquery()
+
+        self.assert_compile(
+            sess.query(T1.id, subq.c.count).select_from(subq).join(T1, subq.c.t1_id==T1.id),
+            "SELECT table1.id AS table1_id, anon_1.count AS anon_1_count "
+            "FROM (SELECT table2.t1_id AS t1_id, count(table2.id) AS count "
+            "FROM table2 GROUP BY table2.t1_id) AS anon_1 JOIN table1 "
+            "ON anon_1.t1_id = table1.id"
+        )
+
+    def test_mapped_select_to_select_explicit_left(self):
+        T1, T2 = self.classes.T1, self.classes.T2
+
+        sess = Session()
+        subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\
+                    group_by(T2.t1_id).subquery()
+
+        self.assert_compile(
+            sess.query(T1.id, subq.c.count).select_from(T1).join(subq, subq.c.t1_id==T1.id),
+            "SELECT table1.id AS table1_id, anon_1.count AS anon_1_count "
+            "FROM table1 JOIN (SELECT table2.t1_id AS t1_id, count(table2.id) AS count "
+            "FROM table2 GROUP BY table2.t1_id) AS anon_1 "
+            "ON anon_1.t1_id = table1.id"
+        )
+
+    def test_mapped_select_to_select_implicit_left(self):
+        T1, T2 = self.classes.T1, self.classes.T2
+
+        sess = Session()
+        subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\
+                    group_by(T2.t1_id).subquery()
+
+        self.assert_compile(
+            sess.query(T1.id, subq.c.count).join(subq, subq.c.t1_id==T1.id),
+            "SELECT table1.id AS table1_id, anon_1.count AS anon_1_count "
+            "FROM table1 JOIN (SELECT table2.t1_id AS t1_id, count(table2.id) AS count "
+            "FROM table2 GROUP BY table2.t1_id) AS anon_1 "
+            "ON anon_1.t1_id = table1.id"
+        )
 
 class MultiplePathTest(fixtures.MappedTest, AssertsCompiledSQL):
     @classmethod