]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- improve overlapping selectables, apply to both query and relationship
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Jun 2013 22:23:06 +0000 (18:23 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Jun 2013 22:23:06 +0000 (18:23 -0400)
- clean up inspect() calls within query._join()
- make sure join.alias(flat) propagates
- fix almost all assertion tests

lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/util.py
test/orm/inheritance/test_relationship.py
test/orm/test_joins.py
test/orm/test_mapper.py
test/orm/test_of_type.py

index 8ab81cb130e1f382c3a2d6a3aa40668683db9c70..39ed8d8bfee29398544cd52aaa5efc868ea2de7d 100644 (file)
@@ -1797,14 +1797,6 @@ class Query(object):
                                 right_entity, onclause,
                                 outerjoin, create_aliases, prop)
 
-    def _tables_overlap(self, left, right):
-        """Return True if parent/child tables have some overlap."""
-
-        return  bool(
-            set(sql_util.find_tables(left)).intersection(
-                sql_util.find_tables(right)
-            )
-        )
 
     def _join_left_to_right(self, left, right,
                             onclause, outerjoin, create_aliases, prop):
@@ -1825,14 +1817,19 @@ class Query(object):
                         "are the same entity" %
                         (left, right))
 
-        # TODO: get the l_info, r_info passed into
-        # the methods so inspect() doesnt need to be called again
         l_info = inspect(left)
         r_info = inspect(right)
-        overlap = self._tables_overlap(l_info.selectable, r_info.selectable)
+
+        overlap = not create_aliases and \
+                        sql_util.selectables_overlap(l_info.selectable,
+                            r_info.selectable)
+        if overlap and l_info.selectable is r_info.selectable:
+            raise sa_exc.InvalidRequestError(
+                    "Can't join table/selectable '%s' to itself" %
+                        l_info.selectable)
 
         right, onclause = self._prepare_right_side(
-                                            right, onclause,
+                                r_info, right, onclause,
                                             create_aliases,
                                             prop, overlap)
 
@@ -1846,10 +1843,11 @@ class Query(object):
         else:
             self._joinpoint = {'_joinpoint_entity': right}
 
-        self._join_to_left(left, right, onclause, outerjoin)
+        self._join_to_left(l_info, left, right, onclause, outerjoin)
 
-    def _prepare_right_side(self, right, onclause, create_aliases, prop, overlap):
-        info = inspect(right)
+    def _prepare_right_side(self, r_info, right, onclause, create_aliases,
+                                    prop, overlap):
+        info = r_info
 
         right_mapper, right_selectable, right_is_aliased = \
             getattr(info, 'mapper', None), \
@@ -1931,8 +1929,8 @@ class Query(object):
 
         return right, onclause
 
-    def _join_to_left(self, left, right, onclause, outerjoin):
-        info = inspect(left)
+    def _join_to_left(self, l_info, left, right, onclause, outerjoin):
+        info = l_info
         left_mapper = getattr(info, 'mapper', None)
         left_selectable = info.selectable
 
index 95fa28613476fbd9747351dcc25573d5f4eafa94..33377d3ec17aeb3ce51b7286f998d73803451e10 100644 (file)
@@ -17,7 +17,7 @@ from .. import sql, util, exc as sa_exc, schema
 from ..sql.util import (
     ClauseAdapter,
     join_condition, _shallow_annotate, visit_binary_product,
-    _deep_deannotate, find_tables
+    _deep_deannotate, find_tables, selectables_overlap
     )
 from ..sql import operators, expression, visitors
 from .interfaces import MANYTOMANY, MANYTOONE, ONETOMANY
@@ -404,11 +404,7 @@ class JoinCondition(object):
     def _tables_overlap(self):
         """Return True if parent/child tables have some overlap."""
 
-        return  bool(
-            set(find_tables(self.parent_selectable)).intersection(
-                find_tables(self.child_selectable)
-            )
-        )
+        return selectables_overlap(self.parent_selectable, self.child_selectable)
 
     def _annotate_remote(self):
         """Annotate the primaryjoin and secondaryjoin
index 633a3ddba7b7bca4328cac01455a8dbd354e7a86..e7ef3cb7288ea5895dd4ec82aa9367557917f63a 100644 (file)
@@ -4001,7 +4001,8 @@ class Join(FromClause):
         """
         if flat:
             assert name is None, "Can't send name argument with flat"
-            left_a, right_a = self.left.alias(), self.right.alias()
+            left_a, right_a = self.left.alias(flat=True), \
+                                self.right.alias(flat=True)
             adapter = sqlutil.ClauseAdapter(left_a).\
                         chain(sqlutil.ClauseAdapter(right_a))
 
index c80693706632d98c1d933c7726a3b3b21c2db896..6f4d27e1bdebce3e075de0aa9c8ea294ec9944a4 100644 (file)
@@ -200,15 +200,28 @@ def clause_is_present(clause, search):
 
     """
 
-    stack = [search]
-    while stack:
-        elem = stack.pop()
+    for elem in surface_selectables(search):
         if clause == elem:  # use == here so that Annotated's compare
             return True
-        elif isinstance(elem, expression.Join):
+    else:
+        return False
+
+def surface_selectables(clause):
+    stack = [clause]
+    while stack:
+        elem = stack.pop()
+        yield elem
+        if isinstance(elem, expression.Join):
             stack.extend((elem.left, elem.right))
-    return False
 
+def selectables_overlap(left, right):
+    """Return True if left/right have some overlapping selectable"""
+
+    return bool(
+                set(surface_selectables(left)).intersection(
+                        surface_selectables(right)
+                    )
+            )
 
 def bind_values(clause):
     """Return an ordered list of "bound" values in the given clause.
index d8cf5ebd8a7a2387152d818f3b361b4e6450fd8f..3f1eb849fb646ce720df055dbb9f3b9ba6ffc870 100644 (file)
@@ -546,11 +546,10 @@ class SelfReferentialM2MTest(fixtures.MappedTest, AssertsCompiledSQL):
             "SELECT child2.id AS child2_id, parent.id AS parent_id, "
             "parent.cls AS parent_cls FROM secondary AS secondary_1, "
             "parent JOIN child2 ON parent.id = child2.id JOIN secondary AS "
-            "secondary_2 ON parent.id = secondary_2.left_id JOIN (SELECT "
-            "parent.id AS parent_id, parent.cls AS parent_cls, child1.id AS "
-            "child1_id FROM parent JOIN child1 ON parent.id = child1.id) AS "
-            "anon_1 ON anon_1.parent_id = secondary_2.right_id WHERE "
-            "anon_1.parent_id = secondary_1.right_id AND :param_1 = "
+            "secondary_2 ON parent.id = secondary_2.left_id JOIN "
+            "(parent AS parent_1 JOIN child1 AS child1_1 ON parent_1.id = child1_1.id) "
+            "ON parent_1.id = secondary_2.right_id WHERE "
+            "parent_1.id = secondary_1.right_id AND :param_1 = "
             "secondary_1.left_id",
             dialect=default.DefaultDialect()
         )
index 2dac591501c4d021b608fa5c2ebd5683b6d406b7..cb9412e1dbfd8dd54d83dc723768b1f2eaa2b622 100644 (file)
@@ -203,15 +203,11 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL):
 
         self.assert_compile(
             sess.query(Company).join(Company.employees.of_type(Engineer)),
-            "SELECT companies.company_id AS companies_company_id, companies.name AS companies_name "
-            "FROM companies JOIN (SELECT people.person_id AS people_person_id, "
-            "people.company_id AS people_company_id, people.name AS people_name, "
-            "people.type AS people_type, engineers.person_id AS "
-            "engineers_person_id, engineers.status AS engineers_status, "
-            "engineers.engineer_name AS engineers_engineer_name, "
-            "engineers.primary_language AS engineers_primary_language "
-            "FROM people JOIN engineers ON people.person_id = engineers.person_id) AS "
-            "anon_1 ON companies.company_id = anon_1.people_company_id"
+            "SELECT companies.company_id AS companies_company_id, "
+            "companies.name AS companies_name "
+            "FROM companies JOIN "
+            "(people JOIN engineers ON people.person_id = engineers.person_id) "
+            "ON companies.company_id = people.company_id"
             , use_default_dialect = True
         )
 
@@ -259,7 +255,7 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL):
             , use_default_dialect=True
         )
 
-    def test_explicit_polymorphic_join(self):
+    def test_explicit_polymorphic_join_one(self):
         Company, Engineer = self.classes.Company, self.classes.Engineer
 
         sess = create_session()
@@ -268,35 +264,28 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL):
             sess.query(Company).join(Engineer).filter(Engineer.engineer_name=='vlad'),
             "SELECT companies.company_id AS companies_company_id, companies.name AS "
             "companies_name "
-            "FROM companies JOIN (SELECT people.person_id AS people_person_id, "
-            "people.company_id AS "
-            "people_company_id, people.name AS people_name, people.type AS people_type,"
-            " engineers.person_id AS "
-            "engineers_person_id, engineers.status AS engineers_status, "
-            "engineers.engineer_name AS engineers_engineer_name, "
-            "engineers.primary_language AS engineers_primary_language "
-            "FROM people JOIN engineers ON people.person_id = engineers.person_id) "
-            "AS anon_1 ON "
-            "companies.company_id = anon_1.people_company_id "
-            "WHERE anon_1.engineers_engineer_name = :engineer_name_1"
+            "FROM companies JOIN (people JOIN engineers "
+                "ON people.person_id = engineers.person_id) "
+            "ON "
+            "companies.company_id = people.company_id "
+            "WHERE engineers.engineer_name = :engineer_name_1"
             , use_default_dialect=True
         )
+
+    def test_explicit_polymorphic_join_two(self):
+        Company, Engineer = self.classes.Company, self.classes.Engineer
+
+        sess = create_session()
         self.assert_compile(
             sess.query(Company).join(Engineer, Company.company_id==Engineer.company_id).
                     filter(Engineer.engineer_name=='vlad'),
             "SELECT companies.company_id AS companies_company_id, companies.name "
             "AS companies_name "
-            "FROM companies JOIN (SELECT people.person_id AS people_person_id, "
-            "people.company_id AS "
-            "people_company_id, people.name AS people_name, people.type AS "
-            "people_type, engineers.person_id AS "
-            "engineers_person_id, engineers.status AS engineers_status, "
-            "engineers.engineer_name AS engineers_engineer_name, "
-            "engineers.primary_language AS engineers_primary_language "
-            "FROM people JOIN engineers ON people.person_id = engineers.person_id) AS "
-            "anon_1 ON "
-            "companies.company_id = anon_1.people_company_id "
-            "WHERE anon_1.engineers_engineer_name = :engineer_name_1"
+            "FROM companies JOIN "
+            "(people JOIN engineers ON people.person_id = engineers.person_id) "
+            "ON "
+            "companies.company_id = people.company_id "
+            "WHERE engineers.engineer_name = :engineer_name_1"
             , use_default_dialect=True
         )
 
@@ -319,16 +308,10 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL):
                 filter(Engineer.name=='dilbert'),
             "SELECT companies.company_id AS companies_company_id, companies.name AS "
             "companies_name "
-            "FROM companies JOIN (SELECT people.person_id AS people_person_id, "
-            "people.company_id AS "
-            "people_company_id, people.name AS people_name, people.type AS "
-            "people_type, engineers.person_id "
-            "AS engineers_person_id, engineers.status AS engineers_status, "
-            "engineers.engineer_name AS engineers_engineer_name, "
-            "engineers.primary_language AS engineers_primary_language FROM people "
+            "FROM companies JOIN (people "
             "JOIN engineers ON people.person_id = "
-            "engineers.person_id) AS anon_1 ON companies.company_id = "
-            "anon_1.people_company_id WHERE anon_1.people_name = :name_1"
+            "engineers.person_id) ON companies.company_id = "
+            "people.company_id WHERE people.name = :name_1"
             , use_default_dialect = True
         )
 
@@ -339,20 +322,14 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL):
                 filter(Engineer.name=='dilbert').filter(Machine.name=='foo'),
             "SELECT companies.company_id AS companies_company_id, companies.name AS "
             "companies_name "
-            "FROM companies JOIN (SELECT people.person_id AS people_person_id, "
-            "people.company_id AS "
-            "people_company_id, people.name AS people_name, people.type AS people_type,"
-            " engineers.person_id "
-            "AS engineers_person_id, engineers.status AS engineers_status, "
-            "engineers.engineer_name AS engineers_engineer_name, "
-            "engineers.primary_language AS engineers_primary_language FROM people "
+            "FROM companies JOIN (people "
             "JOIN engineers ON people.person_id = "
-            "engineers.person_id) AS anon_1 ON companies.company_id = "
-            "anon_1.people_company_id JOIN "
+            "engineers.person_id) ON companies.company_id = "
+            "people.company_id JOIN "
             "(SELECT machines.machine_id AS machine_id, machines.name AS name, "
             "machines.engineer_id AS engineer_id "
-            "FROM machines) AS anon_2 ON anon_1.engineers_person_id = anon_2.engineer_id "
-            "WHERE anon_1.people_name = :name_1 AND anon_2.name = :name_2"
+            "FROM machines) AS anon_1 ON engineers.person_id = anon_1.engineer_id "
+            "WHERE people.name = :name_1 AND anon_1.name = :name_2"
             , use_default_dialect = True
         )
 
@@ -1364,19 +1341,13 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
 
         assert_raises_message(
             sa_exc.InvalidRequestError,
-            "Could not find a FROM clause to join from.  Tried joining "
-            "to .*?, but got: "
-            "Can't find any foreign key relationships "
-            "between 'users' and 'users'.",
+            "Can't join table/selectable 'users' to itself",
             sess.query(users.c.id).join, User
         )
 
         assert_raises_message(
             sa_exc.InvalidRequestError,
-            "Could not find a FROM clause to join from.  Tried joining "
-            "to .*?, but got: "
-            "Can't find any foreign key relationships "
-            "between 'users' and 'users'.",
+            "Can't join table/selectable 'users' to itself",
             sess.query(users.c.id).select_from(users).join, User
         )
 
@@ -1522,16 +1493,22 @@ class JoinFromSelectableTest(fixtures.MappedTest, AssertsCompiledSQL):
         subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\
                     group_by(T2.t1_id).subquery()
 
-        # this query is wrong, but verifying behavior stays the same
-        # (or improves, like an error message)
+        assert_raises_message(
+            sa_exc.InvalidRequestError,
+            "Can't join table/selectable 'table1' to itself",
+            sess.query(T1.id, subq.c.count).join, T1, subq.c.t1_id == T1.id
+        )
+
         self.assert_compile(
-            sess.query(T1.id, subq.c.count).join(T1, subq.c.t1_id==T1.id),
-            "SELECT table1.id AS table1_id, anon_1.count AS anon_1_count FROM "
-            "(SELECT table2.t1_id AS t1_id, count(table2.id) AS count FROM "
-            "table2 GROUP BY table2.t1_id) AS anon_1, table1 JOIN table1 "
-            "ON anon_1.t1_id = table1.id"
+            sess.query(T1.id, subq.c.count).select_from(subq).\
+                join(T1, subq.c.t1_id == T1.id),
+            "SELECT table1.id AS table1_id, anon_1.count AS anon_1_count "
+            "FROM (SELECT table2.t1_id AS t1_id, count(table2.id) AS count "
+            "FROM table2 GROUP BY table2.t1_id) AS anon_1 "
+            "JOIN table1 ON anon_1.t1_id = table1.id"
         )
 
+
     def test_mapped_select_to_mapped_explicit_left(self):
         T1, T2 = self.classes.T1, self.classes.T2
 
index 19ff78004b88d5563925dc6c3500a41885e9b836..ed09e72c10554e48eeb438e233f585074347301c 100644 (file)
@@ -582,20 +582,16 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL):
         self.assert_compile(
             q,
             "SELECT "
-            "anon_1.addresses_id AS anon_1_addresses_id, "
-            "anon_1.users_id AS anon_1_users_id, "
-            "anon_1.users_name AS anon_1_users_name, "
-            "anon_1.addresses_user_id AS anon_1_addresses_user_id, "
-            "anon_1.addresses_email_address AS "
-            "anon_1_addresses_email_address, "
-            "anon_1.users_name || :name_1 AS anon_2 "
-            "FROM addresses JOIN (SELECT users.id AS users_id, "
-            "users.name AS users_name, addresses.id AS addresses_id, "
-            "addresses.user_id AS addresses_user_id, "
-            "addresses.email_address AS addresses_email_address "
-            "FROM users JOIN addresses ON users.id = "
-            "addresses.user_id) AS anon_1 ON "
-            "anon_1.users_id = addresses.user_id"
+            "addresses_1.id AS addresses_1_id, "
+            "users_1.id AS users_1_id, "
+            "users_1.name AS users_1_name, "
+            "addresses_1.user_id AS addresses_1_user_id, "
+            "addresses_1.email_address AS "
+            "addresses_1_email_address, "
+            "users_1.name || :name_1 AS anon_1 "
+            "FROM addresses JOIN (users AS users_1 JOIN addresses AS addresses_1 ON users_1.id = "
+            "addresses_1.user_id) ON "
+            "users_1.id = addresses.user_id"
         )
 
     def test_column_prop_deannotate(self):
index 150673560504155c7fcb4568f0a08bff53c9fef6..d002fd50f734d932b38490af0c7cc62bedbb378a 100644 (file)
@@ -86,7 +86,7 @@ class _PolymorphicTestBase(object):
 
     def test_with_polymorphic_join_compile_one(self):
         sess = Session()
-
+# MARKMARK
         self.assert_compile(
             sess.query(Company).join(
                     Company.employees.of_type(
@@ -194,13 +194,14 @@ class PolymorphicPolymorphicTest(_PolymorphicTestBase, _PolymorphicPolymorphic):
     def _polymorphic_join_target(self, cls):
         from sqlalchemy.orm import class_mapper
 
+        from sqlalchemy.sql.expression import FromGrouping
         m, sel = class_mapper(Person)._with_polymorphic_args(cls)
-        sel = sel.alias()
+        sel = FromGrouping(sel.alias(flat=True))
         comp_sel = sel.compile(dialect=default.DefaultDialect())
 
         return \
             comp_sel.process(sel, asfrom=True).replace("\n", "") + \
-            " ON companies.company_id = anon_1.people_company_id"
+            " ON companies.company_id = people_1.company_id"
 
 class PolymorphicUnionsTest(_PolymorphicTestBase, _PolymorphicUnions):
 
@@ -228,13 +229,14 @@ class PolymorphicAliasedJoinsTest(_PolymorphicTestBase, _PolymorphicAliasedJoins
 class PolymorphicJoinsTest(_PolymorphicTestBase, _PolymorphicJoins):
     def _polymorphic_join_target(self, cls):
         from sqlalchemy.orm import class_mapper
+        from sqlalchemy.sql.expression import FromGrouping
 
-        sel = class_mapper(Person)._with_polymorphic_selectable.alias()
+        sel = FromGrouping(class_mapper(Person)._with_polymorphic_selectable.alias(flat=True))
         comp_sel = sel.compile(dialect=default.DefaultDialect())
 
         return \
             comp_sel.process(sel, asfrom=True).replace("\n", "") + \
-            " ON companies.company_id = anon_1.people_company_id"
+            " ON companies.company_id = people_1.company_id"
 
 
 class SubclassRelationshipTest(testing.AssertsCompiledSQL, fixtures.DeclarativeMappedTest):
@@ -453,6 +455,7 @@ class SubclassRelationshipTest(testing.AssertsCompiledSQL, fixtures.DeclarativeM
                             DataContainer.jobs.of_type(Job_P).\
                                 any(Job_P.id < Job.id)
                         )
+
         self.assert_compile(q,
             "SELECT job.id AS job_id, job.type AS job_type, "
             "job.container_id "
@@ -460,11 +463,10 @@ class SubclassRelationshipTest(testing.AssertsCompiledSQL, fixtures.DeclarativeM
             "FROM data_container "
             "JOIN job ON data_container.id = job.container_id "
             "WHERE EXISTS (SELECT 1 "
-            "FROM (SELECT job.id AS job_id, job.type AS job_type, "
-            "job.container_id AS job_container_id, "
-            "subjob.id AS subjob_id, subjob.attr AS subjob_attr "
-            "FROM job LEFT OUTER JOIN subjob ON job.id = subjob.id) AS anon_1 "
-            "WHERE data_container.id = anon_1.job_container_id AND job.id > anon_1.job_id)"
+            "FROM job AS job_1 LEFT OUTER JOIN subjob AS subjob_1 "
+                "ON job_1.id = subjob_1.id "
+            "WHERE data_container.id = job_1.container_id "
+            "AND job.id > job_1.id)"
         )
 
     def test_any_walias(self):
@@ -506,11 +508,10 @@ class SubclassRelationshipTest(testing.AssertsCompiledSQL, fixtures.DeclarativeM
         self.assert_compile(q,
             "SELECT data_container.id AS data_container_id, "
             "data_container.name AS data_container_name "
-            "FROM data_container JOIN (SELECT job.id AS job_id, "
-            "job.type AS job_type, job.container_id AS job_container_id, "
-            "subjob.id AS subjob_id, subjob.attr AS subjob_attr "
-            "FROM job LEFT OUTER JOIN subjob ON job.id = subjob.id) "
-            "AS anon_1 ON data_container.id = anon_1.job_container_id")
+            "FROM data_container JOIN "
+            "(job AS job_1 LEFT OUTER JOIN subjob AS subjob_1 "
+                "ON job_1.id = subjob_1.id) "
+            "ON data_container.id = job_1.container_id")
 
     def test_join_wsubclass(self):
         ParentThing, DataContainer, Job, SubJob = \
@@ -547,11 +548,9 @@ class SubclassRelationshipTest(testing.AssertsCompiledSQL, fixtures.DeclarativeM
         self.assert_compile(q,
             "SELECT data_container.id AS data_container_id, "
             "data_container.name AS data_container_name "
-            "FROM data_container JOIN (SELECT job.id AS job_id, "
-            "job.type AS job_type, job.container_id AS job_container_id, "
-            "subjob.id AS subjob_id, subjob.attr AS subjob_attr "
-            "FROM job JOIN subjob ON job.id = subjob.id) "
-            "AS anon_1 ON data_container.id = anon_1.job_container_id")
+            "FROM data_container JOIN "
+            "(job AS job_1 JOIN subjob AS subjob_1 ON job_1.id = subjob_1.id) "
+            "ON data_container.id = job_1.container_id")
 
     def test_join_walias(self):
         ParentThing, DataContainer, Job, SubJob = \
@@ -584,9 +583,8 @@ class SubclassRelationshipTest(testing.AssertsCompiledSQL, fixtures.DeclarativeM
         self.assert_compile(q,
             "SELECT data_container.id AS data_container_id, "
             "data_container.name AS data_container_name "
-            "FROM data_container JOIN (SELECT job.id AS job_id, "
-            "job.type AS job_type, job.container_id AS job_container_id, "
-            "subjob.id AS subjob_id, subjob.attr AS subjob_attr "
-            "FROM job LEFT OUTER JOIN subjob ON job.id = subjob.id) "
-            "AS anon_1 ON data_container.id = anon_1.job_container_id")
+            "FROM data_container JOIN "
+            "(job AS job_1 LEFT OUTER JOIN subjob AS subjob_1 "
+            "ON job_1.id = subjob_1.id) "
+            "ON data_container.id = job_1.container_id")