]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- query.get() can be used with a mapping to an outer join
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 9 Nov 2009 23:20:31 +0000 (23:20 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 9 Nov 2009 23:20:31 +0000 (23:20 +0000)
where one or more of the primary key values are None.
[ticket:1135]

CHANGES
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql/util.py
test/orm/test_query.py

diff --git a/CHANGES b/CHANGES
index 2116b9172cf29b5eb21835f73ea9955168c61e64..a59bf971e51fc6aea451c5751339c25aa9efaa60 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -86,6 +86,10 @@ CHANGES
   - query.select_from() accepts multiple clauses to produce 
     multiple comma separated entries within the FROM clause.
     Useful when selecting from multiple-homed join() clauses.
+
+  - query.get() can be used with a mapping to an outer join
+    where one or more of the primary key values are None.
+    [ticket:1135]
     
   - query.from_self(), query.union(), others which do a 
     "SELECT * from (SELECT...)" type of nesting will do
index a4f85f7b17c5bebd5082047f5ccfd76325422c7c..1556db7d25f33272bde13409d86694dcef89e822 100644 (file)
@@ -1402,6 +1402,7 @@ class Query(object):
     def _get(self, key=None, ident=None, refresh_state=None, lockmode=None,
                                         only_load_props=None, passive=None):
         lockmode = lockmode or self._lockmode
+        
         if not self._populate_existing and not refresh_state and \
                 not self._mapper_zero().always_refresh and lockmode is None:
             instance = self.session.identity_map.get(key)
@@ -1436,7 +1437,17 @@ class Query(object):
             mapper = q._mapper_zero()
             params = {}
             (_get_clause, _get_params) = mapper._get_clause
-
+            
+            # None present in ident - turn those comparisons
+            # into "IS NULL"
+            if None in ident:
+                nones = set([
+                            _get_params[col].key for col, value in
+                             zip(mapper.primary_key, ident) if value is None
+                            ])
+                _get_clause = sql_util.adapt_criterion_to_null(
+                                                _get_clause, nones)
+                
             _get_clause = q._adapt_clause(_get_clause, True, False)
             q._criterion = _get_clause
 
index 08c60a1d2376b15ed9e11eba30ed601630d56f75..ff80fd6e6b87999f985dfa38e8cfb25cf86faa0c 100644 (file)
@@ -412,8 +412,12 @@ class LazyLoader(AbstractRelationLoader):
         else:
             (criterion, bind_to_col, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
 
+        if reverse_direction:
+            mapper = self.parent_property.mapper
+        else:
+            mapper = self.parent_property.parent
+
         def visit_bindparam(bindparam):
-            mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent
             if bindparam.key in bind_to_col:
                 # use the "committed" (database) version to get query column values
                 # also its a deferred value; so that when used by Query, the committed value is used
@@ -435,20 +439,8 @@ class LazyLoader(AbstractRelationLoader):
         else:
             (criterion, bind_to_col, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
 
-        def visit_binary(binary):
-            mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent
-            if isinstance(binary.left, expression._BindParamClause) and binary.left.key in bind_to_col:
-                # reverse order if the NULL is on the left side
-                binary.left = binary.right
-                binary.right = expression.null()
-                binary.operator = operators.is_
-                binary.negate = operators.isnot
-            elif isinstance(binary.right, expression._BindParamClause) and binary.right.key in bind_to_col:
-                binary.right = expression.null()
-                binary.operator = operators.is_
-                binary.negate = operators.isnot
-
-        criterion = visitors.cloned_traverse(criterion, {}, {'binary':visit_binary})
+        criterion = sql_util.adapt_criterion_to_null(criterion, bind_to_col)
+
         if adapt_source:
             criterion = adapt_source(criterion)
         return criterion
index d1265c75f385b301edb7f66121df9a6876f45117..a84a3eb7477a2dbf43e6e5cf770ffba4077610e5 100644 (file)
@@ -79,6 +79,24 @@ def find_columns(clause):
     visitors.traverse(clause, {}, {'column':cols.add})
     return cols
 
+def adapt_criterion_to_null(crit, nulls):
+    """given criterion containing bind params, convert selected elements to IS NULL."""
+
+    def visit_binary(binary):
+        if isinstance(binary.left, expression._BindParamClause) and binary.left.key in nulls:
+            # reverse order if the NULL is on the left side
+            binary.left = binary.right
+            binary.right = expression.null()
+            binary.operator = operators.is_
+            binary.negate = operators.isnot
+        elif isinstance(binary.right, expression._BindParamClause) and binary.right.key in nulls:
+            binary.right = expression.null()
+            binary.operator = operators.is_
+            binary.negate = operators.isnot
+
+    return visitors.cloned_traverse(crit, {}, {'binary':visit_binary})
+    
+    
 def join_condition(a, b, ignore_nonexistent_tables=False):
     """create a join condition between two tables.
     
index 98e85fa4110240197e88bc0efb6219a3ccc27f73..83550b060bf82f4c2b05e5c17fb993961b849bdf 100644 (file)
@@ -90,7 +90,25 @@ class GetTest(QueryTest):
         assert one_two.k == 3
         q = s.query(CompositePk)
         assert_raises(sa_exc.InvalidRequestError, q.get, 7)        
+    
+    def test_get_null_pk(self):
+        """test that a mapping which can have None in a 
+        PK (i.e. map to an outerjoin) works with get()."""
+        
+        s = users.outerjoin(addresses)
         
+        class UserThing(_base.ComparableEntity):
+            pass
+            
+        mapper(UserThing, s, properties={
+            'id':(users.c.id, addresses.c.user_id),
+            'address_id':addresses.c.id,
+        })
+        sess = create_session()
+        u10 = sess.query(UserThing).get((10, None))
+        eq_(u10,
+            UserThing(id=10)
+        )
 
     def test_no_criterion(self):
         """test that get()/load() does not use preexisting filter/etc. criterion"""