]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- The join condition produced by with_parent
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 6 Jul 2011 16:44:41 +0000 (12:44 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 6 Jul 2011 16:44:41 +0000 (12:44 -0400)
as well as when using a "dynamic" relationship
against a parent will generate unique
bindparams, rather than incorrectly repeating
the same bindparam.  [ticket:2207].

CHANGES
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/util.py
test/orm/test_query.py

diff --git a/CHANGES b/CHANGES
index 7929738aa8a45227175a766dd0602fb86a753de7..a995c4270386ad2432bb2f208012fb591372b7fb 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -11,6 +11,12 @@ CHANGES
     joinedload + LIMIT + order by the column
     property() occurred.  [ticket:2188].
 
+  - The join condition produced by with_parent
+    as well as when using a "dynamic" relationship
+    against a parent will generate unique
+    bindparams, rather than incorrectly repeating 
+    the same bindparam.  [ticket:2207].
+
   - Repaired the "no statement condition" 
     assertion in Query which would attempt
     to raise if a generative method were called
index 44cbb76586a694a6b4c9f2761850e8f029d36cdf..ac49d335d95411d214cc8ee779703c1b838ad1dc 100644 (file)
@@ -357,18 +357,24 @@ class LazyLoader(AbstractRelationshipLoader):
 
     def init(self):
         super(LazyLoader, self).init()
-        self.__lazywhere, \
-        self.__bind_to_col, \
+        self._lazywhere, \
+        self._bind_to_col, \
         self._equated_columns = self._create_lazy_clause(self.parent_property)
 
-        self.logger.info("%s lazy loading clause %s", self, self.__lazywhere)
+        self._rev_lazywhere, \
+        self._rev_bind_to_col, \
+        self._rev_equated_columns = self._create_lazy_clause(
+                                                self.parent_property, 
+                                                reverse_direction=True)
+
+        self.logger.info("%s lazy loading clause %s", self, self._lazywhere)
 
         # determine if our "lazywhere" clause is the same as the mapper's
         # get() clause.  then we can just use mapper.get()
         #from sqlalchemy.orm import query
         self.use_get = not self.uselist and \
                         self.mapper._get_clause[0].compare(
-                            self.__lazywhere, 
+                            self._lazywhere, 
                             use_proxies=True, 
                             equivalents=self.mapper._equivalent_columns
                         )
@@ -414,14 +420,14 @@ class LazyLoader(AbstractRelationshipLoader):
 
         if not reverse_direction:
             criterion, bind_to_col, rev = \
-                                            self.__lazywhere, \
-                                            self.__bind_to_col, \
+                                            self._lazywhere, \
+                                            self._bind_to_col, \
                                             self._equated_columns
         else:
             criterion, bind_to_col, rev = \
-                                LazyLoader._create_lazy_clause(
-                                        self.parent_property,
-                                        reverse_direction=reverse_direction)
+                                            self._rev_lazywhere, \
+                                            self._rev_bind_to_col, \
+                                            self._rev_equated_columns
 
         if reverse_direction:
             mapper = self.parent_property.mapper
@@ -437,15 +443,18 @@ class LazyLoader(AbstractRelationshipLoader):
         sess = sessionlib._state_session(state)
         if sess is not None and sess._flushing:
             def visit_bindparam(bindparam):
-                if bindparam.key in bind_to_col:
+                if bindparam._identifying_key in bind_to_col:
                     bindparam.value = \
-                                lambda: mapper._get_committed_state_attr_by_column(
-                                        state, dict_, bind_to_col[bindparam.key])
+                        lambda: mapper._get_committed_state_attr_by_column(
+                                state, dict_, 
+                                bind_to_col[bindparam._identifying_key])
         else:
             def visit_bindparam(bindparam):
-                if bindparam.key in bind_to_col:
-                    bindparam.value = lambda: mapper._get_state_attr_by_column(
-                                            state, dict_, bind_to_col[bindparam.key])
+                if bindparam._identifying_key in bind_to_col:
+                    bindparam.value = \
+                        lambda: mapper._get_state_attr_by_column(
+                                state, dict_, 
+                                bind_to_col[bindparam._identifying_key])
 
 
         if self.parent_property.secondary is not None and alias_secondary:
@@ -463,14 +472,14 @@ class LazyLoader(AbstractRelationshipLoader):
     def _lazy_none_clause(self, reverse_direction=False, adapt_source=None):
         if not reverse_direction:
             criterion, bind_to_col, rev = \
-                                        self.__lazywhere, \
-                                        self.__bind_to_col,\
+                                        self._lazywhere, \
+                                        self._bind_to_col,\
                                         self._equated_columns
         else:
             criterion, bind_to_col, rev = \
-                                LazyLoader._create_lazy_clause(
-                                    self.parent_property,
-                                    reverse_direction=reverse_direction)
+                                            self._rev_lazywhere, \
+                                            self._rev_bind_to_col, \
+                                            self._rev_equated_columns
 
         criterion = sql_util.adapt_criterion_to_null(criterion, bind_to_col)
 
@@ -535,7 +544,7 @@ class LazyLoader(AbstractRelationshipLoader):
                     if equated in binds:
                         return None
                 if col not in binds:
-                    binds[col] = sql.bindparam(None, None, type_=col.type)
+                    binds[col] = sql.bindparam(None, None, type_=col.type, unique=True)
                 return binds[col]
             return None
 
index f2c9bc7af1f2d59d308f60413c8513c4d8e8b870..0035b57cd584b3d3f769c60f357191eb0745db6c 100644 (file)
@@ -2444,7 +2444,16 @@ class _BindParamClause(ColumnElement):
         else:
             self.key = key or _generated_label('%%(%d param)s'
                     % id(self))
+
+        # identifiying key that won't change across
+        # clones, used to identify the bind's logical
+        # identity
+        self._identifying_key = self.key
+
+        # key that was passed in the first place, used to 
+        # generate new keys
         self._orig_key = key or 'param'
+
         self.unique = unique
         self.value = value
         self.isoutparam = isoutparam
index a7ebf15c5f79bb242b450966cfe76749c86f4a12..ef8502f690be4ec450422fc8c2ed9795d59c357e 100644 (file)
@@ -193,13 +193,15 @@ def adapt_criterion_to_null(crit, nulls):
     """given criterion containing bind params, convert selected elements to IS NULL."""
 
     def visit_binary(binary):
-        if isinstance(binary.left, expression._BindParamClause) and binary.left.key in nulls:
+        if isinstance(binary.left, expression._BindParamClause) \
+            and binary.left._identifying_key in nulls:
             # reverse order if the NULL is on the left side
             binary.left = binary.right
             binary.right = expression.null()
             binary.operator = operators.is_
             binary.negate = operators.isnot
-        elif isinstance(binary.right, expression._BindParamClause) and binary.right.key in nulls:
+        elif isinstance(binary.right, expression._BindParamClause) \
+            and binary.right._identifying_key in nulls:
             binary.right = expression.null()
             binary.operator = operators.is_
             binary.negate = operators.isnot
index d8e1a4d70bfb73537b6f8fbf742f45b3aa110cbb..d8c2c19233cbb876351ae2c89fce86bcfe4c1a7d 100644 (file)
@@ -831,7 +831,9 @@ class SliceTest(QueryTest):
 
 
 
-class FilterTest(QueryTest):
+class FilterTest(QueryTest, AssertsCompiledSQL):
+    __dialect__ = 'default'
+
     def test_basic(self):
         assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).all()
 
@@ -879,6 +881,23 @@ class FilterTest(QueryTest):
 
         #assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all()
 
+    def test_unique_binds_join_cond(self):
+        """test that binds used when the lazyclause is used in criterion are unique"""
+
+        sess = Session()
+        a1, a2 = sess.query(Address).order_by(Address.id)[0:2]
+        self.assert_compile(
+            sess.query(User).filter(User.addresses.contains(a1)).union(
+                sess.query(User).filter(User.addresses.contains(a2))
+            ),
+            "SELECT anon_1.users_id AS anon_1_users_id, anon_1.users_name AS "
+            "anon_1_users_name FROM (SELECT users.id AS users_id, "
+            "users.name AS users_name FROM users WHERE users.id = :param_1 "
+            "UNION SELECT users.id AS users_id, users.name AS users_name "
+            "FROM users WHERE users.id = :param_2) AS anon_1",
+            checkparams = {u'param_1': 7, u'param_2': 8}
+        )
+
     def test_any(self):
         sess = create_session()
 
@@ -1467,7 +1486,9 @@ class TextTest(QueryTest):
 
         eq_(s.query(User.id, "name").order_by(User.id).all(), [(7, u'jack'), (8, u'ed'), (9, u'fred'), (10, u'chuck')])
 
-class ParentTest(QueryTest):
+class ParentTest(QueryTest, AssertsCompiledSQL):
+    __dialect__ = 'default'
+
     def test_o2m(self):
         sess = create_session()
         q = sess.query(User)
@@ -1921,6 +1942,7 @@ class AddEntityEquivalenceTest(_base.MappedTest, AssertsCompiledSQL):
             )
 
 class JoinTest(QueryTest, AssertsCompiledSQL):
+    __dialect__ = 'default'
 
     def test_single_name(self):
         sess = create_session()
@@ -2672,8 +2694,45 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             use_default_dialect=True
         )
 
+    def test_unique_binds_union(self):
+        """bindparams used in the 'parent' query are unique"""
 
+        sess = Session()
+        u1, u2 = sess.query(User).order_by(User.id)[0:2]
 
+        q1 = sess.query(Address).with_parent(u1, 'addresses')
+        q2 = sess.query(Address).with_parent(u2, 'addresses')
+
+        self.assert_compile(
+            q1.union(q2),
+            "SELECT anon_1.addresses_id AS anon_1_addresses_id, "
+            "anon_1.addresses_user_id AS anon_1_addresses_user_id, "
+            "anon_1.addresses_email_address AS "
+            "anon_1_addresses_email_address FROM (SELECT addresses.id AS "
+            "addresses_id, addresses.user_id AS addresses_user_id, "
+            "addresses.email_address AS addresses_email_address FROM "
+            "addresses WHERE :param_1 = addresses.user_id UNION SELECT "
+            "addresses.id AS addresses_id, addresses.user_id AS "
+            "addresses_user_id, addresses.email_address AS addresses_email_address "
+            "FROM addresses WHERE :param_2 = addresses.user_id) AS anon_1",
+            checkparams={u'param_1': 7, u'param_2': 8},
+        )
+
+    def test_unique_binds_or(self):
+
+        sess = Session()
+        u1, u2 = sess.query(User).order_by(User.id)[0:2]
+
+        self.assert_compile(
+            sess.query(Address).filter(
+                or_(with_parent(u1, 'addresses'), with_parent(u2, 'addresses'))
+            ),
+            "SELECT addresses.id AS addresses_id, addresses.user_id AS "
+            "addresses_user_id, addresses.email_address AS "
+            "addresses_email_address FROM addresses WHERE "
+            ":param_1 = addresses.user_id OR :param_2 = addresses.user_id",
+            checkparams={u'param_1': 7, u'param_2': 8},
+        )
 
     def test_from_self_resets_joinpaths(self):
         """test a join from from_self() doesn't confuse joins inside the subquery