From: Mike Bayer Date: Mon, 9 Nov 2009 23:20:31 +0000 (+0000) Subject: - query.get() can be used with a mapping to an outer join X-Git-Tag: rel_0_6beta1~178 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c6724a3ff0c88d77431d7e29ed2747eb90953c95;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - query.get() can be used with a mapping to an outer join where one or more of the primary key values are None. [ticket:1135] --- diff --git a/CHANGES b/CHANGES index 2116b9172c..a59bf971e5 100644 --- a/CHANGES +++ b/CHANGES @@ -86,6 +86,10 @@ CHANGES - query.select_from() accepts multiple clauses to produce multiple comma separated entries within the FROM clause. Useful when selecting from multiple-homed join() clauses. + + - query.get() can be used with a mapping to an outer join + where one or more of the primary key values are None. + [ticket:1135] - query.from_self(), query.union(), others which do a "SELECT * from (SELECT...)" type of nesting will do diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index a4f85f7b17..1556db7d25 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1402,6 +1402,7 @@ class Query(object): def _get(self, key=None, ident=None, refresh_state=None, lockmode=None, only_load_props=None, passive=None): lockmode = lockmode or self._lockmode + if not self._populate_existing and not refresh_state and \ not self._mapper_zero().always_refresh and lockmode is None: instance = self.session.identity_map.get(key) @@ -1436,7 +1437,17 @@ class Query(object): mapper = q._mapper_zero() params = {} (_get_clause, _get_params) = mapper._get_clause - + + # None present in ident - turn those comparisons + # into "IS NULL" + if None in ident: + nones = set([ + _get_params[col].key for col, value in + zip(mapper.primary_key, ident) if value is None + ]) + _get_clause = sql_util.adapt_criterion_to_null( + _get_clause, nones) + _get_clause = q._adapt_clause(_get_clause, True, False) q._criterion = _get_clause diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 08c60a1d23..ff80fd6e6b 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -412,8 +412,12 @@ class LazyLoader(AbstractRelationLoader): else: (criterion, bind_to_col, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction) + if reverse_direction: + mapper = self.parent_property.mapper + else: + mapper = self.parent_property.parent + def visit_bindparam(bindparam): - mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent if bindparam.key in bind_to_col: # use the "committed" (database) version to get query column values # also its a deferred value; so that when used by Query, the committed value is used @@ -435,20 +439,8 @@ class LazyLoader(AbstractRelationLoader): else: (criterion, bind_to_col, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction) - def visit_binary(binary): - mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent - if isinstance(binary.left, expression._BindParamClause) and binary.left.key in bind_to_col: - # reverse order if the NULL is on the left side - binary.left = binary.right - binary.right = expression.null() - binary.operator = operators.is_ - binary.negate = operators.isnot - elif isinstance(binary.right, expression._BindParamClause) and binary.right.key in bind_to_col: - binary.right = expression.null() - binary.operator = operators.is_ - binary.negate = operators.isnot - - criterion = visitors.cloned_traverse(criterion, {}, {'binary':visit_binary}) + criterion = sql_util.adapt_criterion_to_null(criterion, bind_to_col) + if adapt_source: criterion = adapt_source(criterion) return criterion diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index d1265c75f3..a84a3eb747 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -79,6 +79,24 @@ def find_columns(clause): visitors.traverse(clause, {}, {'column':cols.add}) return cols +def adapt_criterion_to_null(crit, nulls): + """given criterion containing bind params, convert selected elements to IS NULL.""" + + def visit_binary(binary): + if isinstance(binary.left, expression._BindParamClause) and binary.left.key in nulls: + # reverse order if the NULL is on the left side + binary.left = binary.right + binary.right = expression.null() + binary.operator = operators.is_ + binary.negate = operators.isnot + elif isinstance(binary.right, expression._BindParamClause) and binary.right.key in nulls: + binary.right = expression.null() + binary.operator = operators.is_ + binary.negate = operators.isnot + + return visitors.cloned_traverse(crit, {}, {'binary':visit_binary}) + + def join_condition(a, b, ignore_nonexistent_tables=False): """create a join condition between two tables. diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 98e85fa411..83550b060b 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -90,7 +90,25 @@ class GetTest(QueryTest): assert one_two.k == 3 q = s.query(CompositePk) assert_raises(sa_exc.InvalidRequestError, q.get, 7) + + def test_get_null_pk(self): + """test that a mapping which can have None in a + PK (i.e. map to an outerjoin) works with get().""" + + s = users.outerjoin(addresses) + class UserThing(_base.ComparableEntity): + pass + + mapper(UserThing, s, properties={ + 'id':(users.c.id, addresses.c.user_id), + 'address_id':addresses.c.id, + }) + sess = create_session() + u10 = sess.query(UserThing).get((10, None)) + eq_(u10, + UserThing(id=10) + ) def test_no_criterion(self): """test that get()/load() does not use preexisting filter/etc. criterion"""