]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- "not equals" comparisons of simple many-to-one relation
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 18 Oct 2008 17:34:52 +0000 (17:34 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 18 Oct 2008 17:34:52 +0000 (17:34 +0000)
to an instance will not drop into an EXISTS clause
and will compare foreign key columns instead.

- removed not-really-working use cases of comparing
a collection to an iterable.  Use contains() to test
for collection membership.

- Further simplified SELECT compilation and its relationship
to result row processing.

- Direct execution of a union() construct will properly set up
result-row processing. [ticket:1194]

CHANGES
lib/sqlalchemy/databases/maxdb.py
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/sybase.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/compiler.py
test/orm/query.py
test/sql/testtypes.py

diff --git a/CHANGES b/CHANGES
index 19c2d40b377eb62ec0307b6090541fa5aba1009e..29a52fb3faa5a2ba957d7e868f09943942bae4e2 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -59,11 +59,25 @@ CHANGES
     - Adjustment to Session's post-flush accounting of newly
       "clean" objects to better protect against operating on
       objects as they're asynchronously gc'ed. [ticket:1182]
+    
+    - "not equals" comparisons of simple many-to-one relation
+      to an instance will not drop into an EXISTS clause
+      and will compare foreign key columns instead.
       
+    - removed not-really-working use cases of comparing 
+      a collection to an iterable.  Use contains() to test
+      for collection membership.
+        
 - sql
     - column.in_(someselect) can now be used as a columns-clause
       expression without the subquery bleeding into the FROM clause
       [ticket:1074]
+      
+    - Further simplified SELECT compilation and its relationship
+      to result row processing.
+    
+    - Direct execution of a union() construct will properly set up
+      result-row processing. [ticket:1194]
 
 - sqlite
     - Overhauled SQLite date/time bind/result processing to use
index 34629b298aa4336773210a30620ae27b155f8302..0e7310ab6086704ae4fb31e18e330e13a35562e8 100644 (file)
@@ -829,7 +829,7 @@ class MaxDBCompiler(compiler.DefaultCompiler):
 
         # No ORDER BY in subqueries.
         if order_by:
-            if self.is_subquery(select):
+            if self.is_subquery():
                 # It's safe to simply drop the ORDER BY if there is no
                 # LIMIT.  Right?  Other dialects seem to get away with
                 # dropping order.
@@ -845,7 +845,7 @@ class MaxDBCompiler(compiler.DefaultCompiler):
     def get_select_precolumns(self, select):
         # Convert a subquery's LIMIT to TOP
         sql = select._distinct and 'DISTINCT ' or ''
-        if self.is_subquery(select) and select._limit:
+        if self.is_subquery() and select._limit:
             if select._offset:
                 raise exc.InvalidRequestError(
                     'MaxDB does not support LIMIT with an offset.')
@@ -855,7 +855,7 @@ class MaxDBCompiler(compiler.DefaultCompiler):
     def limit_clause(self, select):
         # The docs say offsets are supported with LIMIT.  But they're not.
         # TODO: maybe emulate by adding a ROWNO/ROWNUM predicate?
-        if self.is_subquery(select):
+        if self.is_subquery():
             # sub queries need TOP
             return ''
         elif select._offset:
index 4c5ad1fd11026c61902d7f2eea70220b69eaffae..42743870a0a42e0b48d6c264aea407992bde4cb4 100644 (file)
@@ -994,7 +994,7 @@ class MSSQLCompiler(compiler.DefaultCompiler):
         order_by = self.process(select._order_by_clause)
 
         # MSSQL only allows ORDER BY in subqueries if there is a LIMIT
-        if order_by and (not self.is_subquery(select) or select._limit):
+        if order_by and (not self.is_subquery() or select._limit):
             return " ORDER BY " + order_by
         else:
             return ""
index 5c64ec1aecee7371e4b8cf68c945786253627672..b464a3bcbe5e9050403d294f2f723a6eb33d7878 100644 (file)
@@ -798,7 +798,7 @@ class SybaseSQLCompiler(compiler.DefaultCompiler):
         order_by = self.process(select._order_by_clause)
 
         # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT
-        if order_by and (not self.is_subquery(select) or select._limit):
+        if order_by and (not self.is_subquery() or select._limit):
             return " ORDER BY " + order_by
         else:
             return ""
index d2f5dae0c8fbe72014380c4e2ee5aa808f4fdc28..40bca8a119194f3adbf4b5fa6cca07e3875e05e9 100644 (file)
@@ -347,18 +347,7 @@ class PropertyLoader(StrategizedProperty):
                 else:
                     return self.prop._optimized_compare(None)
             elif self.prop.uselist:
-                if not hasattr(other, '__iter__'):
-                    raise sa_exc.InvalidRequestError("Can only compare a collection to an iterable object.  Use contains().")
-                else:
-                    j = self.prop.primaryjoin
-                    if self.prop.secondaryjoin:
-                        j = j & self.prop.secondaryjoin
-                    clauses = []
-                    for o in other:
-                        clauses.append(
-                            sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(o))]))
-                        )
-                    return sql.and_(*clauses)
+                raise sa_exc.InvalidRequestError("Can't compare a collection to an object or collection; use contains() to test for membership.")
             else:
                 return self.prop._optimized_compare(other)
 
@@ -418,25 +407,30 @@ class PropertyLoader(StrategizedProperty):
             return clause
 
         def __negated_contains_or_equals(self, other):
+            if self.prop.direction == MANYTOONE:
+                state = attributes.instance_state(other)
+                strategy = self.prop._get_strategy(strategies.LazyLoader)
+                if strategy.use_get:
+                    return sql.and_(*[
+                        sql.or_(
+                        x !=
+                        self.prop.mapper._get_committed_state_attr_by_column(state, y),
+                        x == None)
+                        for (x, y) in self.prop.local_remote_pairs])
+                    
             criterion = sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])
             return ~self._criterion_exists(criterion)
 
         def __ne__(self, other):
-            # TODO: simplify MANYTOONE comparsion when 
-            # the 'use_get' flag is enabled
-            
             if other is None:
                 if self.prop.direction == MANYTOONE:
                     return sql.or_(*[x!=None for x in self.prop._foreign_keys])
-                elif self.prop.uselist:
-                    return self.any()
                 else:
-                    return self.has()
-
-            if self.prop.uselist and not hasattr(other, '__iter__'):
-                raise sa_exc.InvalidRequestError("Can only compare a collection to an iterable object")
-
-            return self.__negated_contains_or_equals(other)
+                    return self._criterion_exists()
+            elif self.prop.uselist:
+                raise sa_exc.InvalidRequestError("Can't compare a collection to an object or collection; use contains() to test for membership.")
+            else:
+                return self.__negated_contains_or_equals(other)
 
     def compare(self, op, value, value_is_parent=False):
         if op == operators.eq:
index df994d689fa716b7e8e0eca1bc6f79080ddca42a..d66a51de42cc285cc923f5954976f747f5141c1a 100644 (file)
@@ -65,20 +65,20 @@ class SchemaItem(object):
     def __repr__(self):
         return "%s()" % self.__class__.__name__
 
+    @property
     def bind(self):
         """Return the connectable associated with this SchemaItem."""
 
         m = self.metadata
         return m and m.bind or None
-    bind = property(bind)
 
+    @property
     def info(self):
         try:
             return self._info
         except AttributeError:
             self._info = {}
             return self._info
-    info = property(info)
 
 
 def _get_table_key(name, schema):
@@ -291,9 +291,9 @@ class Table(SchemaItem, expression.TableClause):
     def __post_init(self, *args, **kwargs):
         self._init_items(*args)
 
+    @property
     def key(self):
         return _get_table_key(self.name, self.schema)
-    key = property(key)
 
     def _set_primary_key(self, pk):
         if getattr(self, '_primary_key', None) in self.constraints:
index 2982a1759629119b23e8e7e3501e52224a75a983..57345349927904a8d41f793954673edb90ec2ea9 100644 (file)
@@ -174,19 +174,13 @@ class DefaultCompiler(engine.Compiled):
     def compile(self):
         self.string = self.process(self.statement)
 
-    def process(self, obj, stack=None, **kwargs):
-        if stack:
-            self.stack.append(stack)
-        try:
-            meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
-            if meth:
-                return meth(obj, **kwargs)
-        finally:
-            if stack:
-                self.stack.pop(-1)
+    def process(self, obj, **kwargs):
+        meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
+        if meth:
+            return meth(obj, **kwargs)
 
-    def is_subquery(self, select):
-        return self.stack and self.stack[-1].get('is_subquery')
+    def is_subquery(self):
+        return self.stack and self.stack[-1].get('from')
 
     def construct_params(self, params=None):
         """return a dictionary of bind parameter keys and values"""
@@ -342,16 +336,9 @@ class DefaultCompiler(engine.Compiled):
         return self.functions.get(func.__class__, self.functions.get(func.name, func.name + "%(expr)s"))
 
     def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs):
-        stack_entry = {'select':cs}
-
-        if asfrom:
-            stack_entry['is_subquery'] = True
-        elif self.stack and self.stack[-1].get('select'):
-            stack_entry['is_subquery'] = True
-        self.stack.append(stack_entry)
 
-        text = string.join((self.process(c, asfrom=asfrom, parens=False)
-                            for c in cs.selects),
+        text = string.join((self.process(c, asfrom=asfrom, parens=False, compound_index=i)
+                            for i, c in enumerate(cs.selects)),
                            " " + cs.keyword + " ")
         group_by = self.process(cs._group_by_clause, asfrom=asfrom)
         if group_by:
@@ -360,8 +347,6 @@ class DefaultCompiler(engine.Compiled):
         text += self.order_by_clause(cs)
         text += (cs._limit is not None or cs._offset is not None) and self.limit_clause(cs) or ""
 
-        self.stack.pop(-1)
-
         if asfrom and parens:
             return "(" + text + ")"
         else:
@@ -470,28 +455,11 @@ class DefaultCompiler(engine.Compiled):
         else:
             return column
 
-    def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, **kwargs):
+    def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, compound_index=1, **kwargs):
 
-        stack_entry = {'select':select}
-        prev_entry = self.stack and self.stack[-1] or None
-
-        if asfrom or (prev_entry and 'select' in prev_entry):
-            stack_entry['is_subquery'] = True
-            stack_entry['iswrapper'] = iswrapper
-            if not iswrapper and prev_entry and 'iswrapper' in prev_entry:
-                column_clause_args = {'result_map':self.result_map}
-            else:
-                column_clause_args = {}
-        elif iswrapper:
-            column_clause_args = {}
-            stack_entry['iswrapper'] = True
-        else:
-            column_clause_args = {'result_map':self.result_map}
-
-        if self.stack and 'from' in self.stack[-1]:
-            existingfroms = self.stack[-1]['from']
-        else:
-            existingfroms = None
+        entry = self.stack and self.stack[-1] or {}
+        
+        existingfroms = entry.get('from', None)
 
         froms = select._get_display_froms(existingfroms)
 
@@ -499,10 +467,15 @@ class DefaultCompiler(engine.Compiled):
 
         # TODO: might want to propigate existing froms for select(select(select))
         # where innermost select should correlate to outermost
-#        if existingfroms:
-#            correlate_froms = correlate_froms.union(existingfroms)
-        stack_entry['from'] = correlate_froms
-        self.stack.append(stack_entry)
+        # if existingfroms:
+        #     correlate_froms = correlate_froms.union(existingfroms)
+
+        if compound_index==1 and not entry or entry.get('iswrapper', False):
+            column_clause_args = {'result_map':self.result_map}
+        else:
+            column_clause_args = {}
+
+        self.stack.append({'from':correlate_froms, 'iswrapper':iswrapper})
 
         # the actual list of columns to print in the SELECT column list.
         inner_columns = util.OrderedSet(
@@ -520,13 +493,9 @@ class DefaultCompiler(engine.Compiled):
         text += self.get_select_precolumns(select)
         text += ', '.join(inner_columns)
 
-        from_strings = []
-        for f in froms:
-            from_strings.append(self.process(f, asfrom=True))
-
         if froms:
             text += " \nFROM "
-            text += ', '.join(from_strings)
+            text += ', '.join(self.process(f, asfrom=True) for f in froms)
         else:
             text += self.default_from()
 
index 151bada633c167d7b3f1c14634880923f45616d8..567ca317c9b639ea0b027734f64f6f4802b409e7 100644 (file)
@@ -5,6 +5,7 @@ from sqlalchemy import exc as sa_exc, util
 from sqlalchemy.sql import compiler
 from sqlalchemy.engine import default
 from sqlalchemy.orm import *
+from sqlalchemy.orm import attributes
 
 from testlib import *
 from orm import _base
@@ -334,7 +335,16 @@ class OperatorTest(QueryTest, AssertsCompiledSQL):
                         "WHERE users.id = addresses.user_id AND addresses.id = :id_1)"
                     )
 
-        self._test(Address.user == User(id=7), ":param_1 = addresses.user_id")
+        u7 = User(id=7)
+        attributes.instance_state(u7).commit_all()
+        
+        self._test(Address.user == u7, ":param_1 = addresses.user_id")
+
+        self._test(Address.user != u7, "addresses.user_id != :user_id_1 OR addresses.user_id IS NULL")
+
+        self._test(Address.user == None, "addresses.user_id IS NULL")
+
+        self._test(Address.user != None, "addresses.user_id IS NOT NULL")
 
     def test_selfref_relation(self):
         nalias = aliased(Node)
index 5d10c57501e90d3739223cdb860cc4da8863f6ab..793695919c37a374c012d90cd2ef457a098d5470 100644 (file)
@@ -288,7 +288,7 @@ class UnicodeTest(TestBase, AssertsExecutionResults):
     def tearDown(self):
         unicode_table.delete().execute()
 
-    def testbasic(self):
+    def test_round_trip(self):
         assert unicode_table.c.unicode_varchar.type.length == 250
         rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n'
         unicodedata = rawdata.decode('utf-8')
@@ -296,10 +296,6 @@ class UnicodeTest(TestBase, AssertsExecutionResults):
                                        unicode_text=unicodedata,
                                        plain_varchar=rawdata)
         x = unicode_table.select().execute().fetchone()
-        print 0, repr(unicodedata)
-        print 1, repr(x['unicode_varchar'])
-        print 2, repr(x['unicode_text'])
-        print 3, repr(x['plain_varchar'])
         self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata)
         self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata)
         if isinstance(x['plain_varchar'], unicode):
@@ -310,7 +306,21 @@ class UnicodeTest(TestBase, AssertsExecutionResults):
         else:
             self.assert_(not isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == rawdata)
 
-    def testassert(self):
+    def test_union(self):
+        """ensure compiler processing works for UNIONs"""
+
+        rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n'
+        unicodedata = rawdata.decode('utf-8')
+        unicode_table.insert().execute(unicode_varchar=unicodedata,
+                                       unicode_text=unicodedata,
+                                       plain_varchar=rawdata)
+                                       
+        x = union(unicode_table.select(), unicode_table.select()).execute().fetchone()
+        self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata)
+        self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata)
+        
+
+    def test_assertions(self):
         try:
             unicode_table.insert().execute(unicode_varchar='not unicode')
             assert False
@@ -337,11 +347,11 @@ class UnicodeTest(TestBase, AssertsExecutionResults):
             unicode_engine.dispose()
 
     @testing.fails_on('oracle')
-    def testblanks(self):
+    def test_blank_strings(self):
         unicode_table.insert().execute(unicode_varchar=u'')
         assert select([unicode_table.c.unicode_varchar]).scalar() == u''
 
-    def testengineparam(self):
+    def test_engine_parameter(self):
         """tests engine-wide unicode conversion"""
         prev_unicode = testing.db.engine.dialect.convert_unicode
         prev_assert = testing.db.engine.dialect.assert_unicode
@@ -367,7 +377,7 @@ class UnicodeTest(TestBase, AssertsExecutionResults):
 
     @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
     @testing.fails_on('firebird') # "Data type unknown" on the parameter
-    def testlength(self):
+    def test_length_function(self):
         """checks the database correctly understands the length of a unicode string"""
         teststr = u'aaa\x1234'
         self.assert_(testing.db.func.length(teststr).scalar() == len(teststr))