From 5c14b20f9f02179e4e59e3f196cbab5da8366583 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 12 Dec 2007 17:56:52 +0000 Subject: [PATCH] implemented many-to-one comparisons to None generate IS NULL, with column on the left side in all cases --- CHANGES | 4 ++++ lib/sqlalchemy/orm/properties.py | 5 ++++- lib/sqlalchemy/orm/strategies.py | 25 ++++++++++++++++++++++++- test/orm/query.py | 5 +++++ test/testlib/fixtures.py | 2 +- 5 files changed, 38 insertions(+), 3 deletions(-) diff --git a/CHANGES b/CHANGES index 38dec541b1..dcc6cbdb2f 100644 --- a/CHANGES +++ b/CHANGES @@ -79,6 +79,10 @@ CHANGES statements as well. Filter criterion, order bys, eager load clauses will be "aliased" against the given statement. + - query.filter(SomeClass.somechild == None), when comparing + a many-to-one property to None, properly generates "id IS NULL" + including that the NULL is on the right side. + - eagerload(), lazyload(), eagerload_all() take an optional second class-or-mapper argument, which will select the mapper to apply the option towards. This can select among other diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 4d41556a07..b6d6cef638 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -232,7 +232,10 @@ class PropertyLoader(StrategizedProperty): class Comparator(PropComparator): def __eq__(self, other): if other is None: - return ~sql.exists([1], self.prop.primaryjoin) + if self.prop.uselist: + return ~sql.exists([1], self.prop.primaryjoin) + else: + return self.prop._optimized_compare(None) elif self.prop.uselist: if not hasattr(other, '__iter__'): raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object. Use contains().") diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index d9390345e9..2ba9d6be1e 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -8,7 +8,7 @@ from sqlalchemy import sql, util, exceptions, logging from sqlalchemy.sql import util as sql_util -from sqlalchemy.sql import visitors +from sqlalchemy.sql import visitors, expression, operators from sqlalchemy.orm import mapper, attributes from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption from sqlalchemy.orm import session as sessionlib @@ -292,6 +292,9 @@ class LazyLoader(AbstractRelationLoader): self._register_attribute(self.parent.class_, callable_=lambda i: self.setup_loader(i)) def lazy_clause(self, instance, reverse_direction=False): + if instance is None: + return self.lazy_none_clause(reverse_direction) + if not reverse_direction: (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.lazyreverse) else: @@ -305,6 +308,26 @@ class LazyLoader(AbstractRelationLoader): bindparam.value = mapper._get_committed_attr_by_column(instance, bind_to_col[bindparam.key]) return visitors.traverse(criterion, clone=True, visit_bindparam=visit_bindparam) + def lazy_none_clause(self, reverse_direction=False): + if not reverse_direction: + (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.lazyreverse) + else: + (criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction) + bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds]) + + def visit_binary(binary): + mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent + if isinstance(binary.left, expression._BindParamClause) and binary.left.key in bind_to_col: + # reverse order if the NULL is on the left side + binary.left = binary.right + binary.right = expression.null() + binary.operator = operators.is_ + elif isinstance(binary.right, expression._BindParamClause) and binary.right.key in bind_to_col: + binary.right = expression.null() + binary.operator = operators.is_ + + return visitors.traverse(criterion, clone=True, visit_binary=visit_binary) + def setup_loader(self, instance, options=None, path=None): if not mapper.has_mapper(instance): return None diff --git a/test/orm/query.py b/test/orm/query.py index 5f85151f0a..efd890e0e8 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -376,6 +376,11 @@ class FilterTest(QueryTest): assert [Address(id=1), Address(id=5)] == sess.query(Address).filter(Address.user!=user).all() + # generates an IS NULL + assert [] == sess.query(Address).filter(Address.user == None).all() + + assert [Order(id=5)] == sess.query(Order).filter(Order.address == None).all() + class AggregateTest(QueryTest): def test_sum(self): sess = create_session() diff --git a/test/testlib/fixtures.py b/test/testlib/fixtures.py index 4394780bb2..2a4b457acd 100644 --- a/test/testlib/fixtures.py +++ b/test/testlib/fixtures.py @@ -150,7 +150,7 @@ def install_fixture_data(): dict(id = 2, user_id = 9, description = 'order 2', isopen=0, address_id=4), dict(id = 3, user_id = 7, description = 'order 3', isopen=1, address_id=1), dict(id = 4, user_id = 9, description = 'order 4', isopen=1, address_id=4), - dict(id = 5, user_id = 7, description = 'order 5', isopen=0, address_id=1) + dict(id = 5, user_id = 7, description = 'order 5', isopen=0, address_id=None) ) items.insert().execute( dict(id=1, description='item 1'), -- 2.47.3