]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Use base __ne__ implementation for range types w/ None
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 3 Apr 2018 19:35:00 +0000 (15:35 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 3 Apr 2018 19:36:48 +0000 (15:36 -0400)
Fixed bug where the special "not equals" operator for the Postgresql
"range" datatypes such as DATERANGE would fail to render "IS NOT NULL" when
compared to the Python ``None`` value.

Also break up range tests into backend round trip and straight
compilation suites.

Change-Id: Ibaee132b1ea7dac8b799495a27f98f82a7d9c028
Fixes: #4229
doc/build/changelog/unreleased_12/4229.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/ranges.py
test/dialect/postgresql/test_types.py

diff --git a/doc/build/changelog/unreleased_12/4229.rst b/doc/build/changelog/unreleased_12/4229.rst
new file mode 100644 (file)
index 0000000..d28146e
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, postgresql
+    :tickets: 4229
+
+    Fixed bug where the special "not equals" operator for the Postgresql
+    "range" datatypes such as DATERANGE would fail to render "IS NOT NULL" when
+    compared to the Python ``None`` value.
+
+
index 38bfb37d448d8ca31c818bd1e5a3c1431402fd0a..eb2d86bbdc8481b1ac5b573fed74cc12181d6bbc 100644 (file)
@@ -33,7 +33,11 @@ class RangeOperators(object):
 
         def __ne__(self, other):
             "Boolean expression. Returns true if two ranges are not equal"
-            return self.expr.op('<>')(other)
+            if other is None:
+                return super(
+                    RangeOperators.comparator_factory, self).__ne__(other)
+            else:
+                return self.expr.op('<>')(other)
 
         def contains(self, other, **kw):
             """Boolean expression. Returns true if the right hand operand,
index 8aa9d1b1f59a5e902a6235e05b68d249ded6d7d5..f5108920db4208baf8b9de9f13b7e0fd3fb1807e 100644 (file)
@@ -2327,63 +2327,22 @@ class HStoreRoundTripTest(fixtures.TablesTest):
         )
 
 
-class _RangeTypeMixin(object):
-    __requires__ = 'range_types', 'psycopg2_compatibility'
-    __backend__ = True
+class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
+    __dialect__ = 'postgresql'
 
-    def extras(self):
-        # done this way so we don't get ImportErrors with
-        # older psycopg2 versions.
-        if testing.against("postgresql+psycopg2cffi"):
-            from psycopg2cffi import extras
-        else:
-            from psycopg2 import extras
-        return extras
+    # operator tests
 
     @classmethod
-    def define_tables(cls, metadata):
-        # no reason ranges shouldn't be primary keys,
-        # so lets just use them as such
-        table = Table('data_table', metadata,
+    def setup_class(cls):
+        table = Table('data_table', MetaData(),
                       Column('range', cls._col_type, primary_key=True),
                       )
         cls.col = table.c.range
 
-    def test_actual_type(self):
-        eq_(str(self._col_type()), self._col_str)
-
-    def test_reflect(self):
-        from sqlalchemy import inspect
-        insp = inspect(testing.db)
-        cols = insp.get_columns('data_table')
-        assert isinstance(cols[0]['type'], self._col_type)
-
-    def _assert_data(self):
-        data = testing.db.execute(
-            select([self.tables.data_table.c.range])
-        ).fetchall()
-        eq_(data, [(self._data_obj(), )])
-
-    def test_insert_obj(self):
-        testing.db.engine.execute(
-            self.tables.data_table.insert(),
-            {'range': self._data_obj()}
-        )
-        self._assert_data()
-
-    def test_insert_text(self):
-        testing.db.engine.execute(
-            self.tables.data_table.insert(),
-            {'range': self._data_str}
-        )
-        self._assert_data()
-
-    # operator tests
-
     def _test_clause(self, colclause, expected):
-        dialect = postgresql.dialect()
-        compiled = str(colclause.compile(dialect=dialect))
-        eq_(compiled, expected)
+        self.assert_compile(
+            colclause, expected
+        )
 
     def test_where_equal(self):
         self._test_clause(
@@ -2397,6 +2356,18 @@ class _RangeTypeMixin(object):
             "data_table.range <> %(range_1)s"
         )
 
+    def test_where_is_null(self):
+        self._test_clause(
+            self.col == None,
+            "data_table.range IS NULL"
+        )
+
+    def test_where_is_not_null(self):
+        self._test_clause(
+            self.col != None,
+            "data_table.range IS NOT NULL"
+        )
+
     def test_where_less_than(self):
         self._test_clause(
             self.col < self._data_str,
@@ -2483,6 +2454,70 @@ class _RangeTypeMixin(object):
             "data_table.range + data_table.range"
         )
 
+    def test_intersection(self):
+        self._test_clause(
+            self.col * self.col,
+            "data_table.range * data_table.range"
+        )
+
+    def test_different(self):
+        self._test_clause(
+            self.col - self.col,
+            "data_table.range - data_table.range"
+        )
+
+
+class _RangeTypeRoundTrip(fixtures.TablesTest):
+    __requires__ = 'range_types', 'psycopg2_compatibility'
+    __backend__ = True
+
+    def extras(self):
+        # done this way so we don't get ImportErrors with
+        # older psycopg2 versions.
+        if testing.against("postgresql+psycopg2cffi"):
+            from psycopg2cffi import extras
+        else:
+            from psycopg2 import extras
+        return extras
+
+    @classmethod
+    def define_tables(cls, metadata):
+        # no reason ranges shouldn't be primary keys,
+        # so lets just use them as such
+        table = Table('data_table', metadata,
+                      Column('range', cls._col_type, primary_key=True),
+                      )
+        cls.col = table.c.range
+
+    def test_actual_type(self):
+        eq_(str(self._col_type()), self._col_str)
+
+    def test_reflect(self):
+        from sqlalchemy import inspect
+        insp = inspect(testing.db)
+        cols = insp.get_columns('data_table')
+        assert isinstance(cols[0]['type'], self._col_type)
+
+    def _assert_data(self):
+        data = testing.db.execute(
+            select([self.tables.data_table.c.range])
+        ).fetchall()
+        eq_(data, [(self._data_obj(), )])
+
+    def test_insert_obj(self):
+        testing.db.engine.execute(
+            self.tables.data_table.insert(),
+            {'range': self._data_obj()}
+        )
+        self._assert_data()
+
+    def test_insert_text(self):
+        testing.db.engine.execute(
+            self.tables.data_table.insert(),
+            {'range': self._data_str}
+        )
+        self._assert_data()
+
     def test_union_result(self):
         # insert
         testing.db.engine.execute(
@@ -2496,12 +2531,6 @@ class _RangeTypeMixin(object):
         ).fetchall()
         eq_(data, [(self._data_obj(), )])
 
-    def test_intersection(self):
-        self._test_clause(
-            self.col * self.col,
-            "data_table.range * data_table.range"
-        )
-
     def test_intersection_result(self):
         # insert
         testing.db.engine.execute(
@@ -2515,12 +2544,6 @@ class _RangeTypeMixin(object):
         ).fetchall()
         eq_(data, [(self._data_obj(), )])
 
-    def test_different(self):
-        self._test_clause(
-            self.col - self.col,
-            "data_table.range - data_table.range"
-        )
-
     def test_difference_result(self):
         # insert
         testing.db.engine.execute(
@@ -2535,7 +2558,7 @@ class _RangeTypeMixin(object):
         eq_(data, [(self._data_obj().__class__(empty=True), )])
 
 
-class Int4RangeTests(_RangeTypeMixin, fixtures.TablesTest):
+class _Int4RangeTests(object):
 
     _col_type = INT4RANGE
     _col_str = 'INT4RANGE'
@@ -2545,7 +2568,7 @@ class Int4RangeTests(_RangeTypeMixin, fixtures.TablesTest):
         return self.extras().NumericRange(1, 2)
 
 
-class Int8RangeTests(_RangeTypeMixin, fixtures.TablesTest):
+class _Int8RangeTests(object):
 
     _col_type = INT8RANGE
     _col_str = 'INT8RANGE'
@@ -2557,7 +2580,7 @@ class Int8RangeTests(_RangeTypeMixin, fixtures.TablesTest):
         )
 
 
-class NumRangeTests(_RangeTypeMixin, fixtures.TablesTest):
+class _NumRangeTests(object):
 
     _col_type = NUMRANGE
     _col_str = 'NUMRANGE'
@@ -2569,7 +2592,7 @@ class NumRangeTests(_RangeTypeMixin, fixtures.TablesTest):
         )
 
 
-class DateRangeTests(_RangeTypeMixin, fixtures.TablesTest):
+class _DateRangeTests(object):
 
     _col_type = DATERANGE
     _col_str = 'DATERANGE'
@@ -2581,7 +2604,7 @@ class DateRangeTests(_RangeTypeMixin, fixtures.TablesTest):
         )
 
 
-class DateTimeRangeTests(_RangeTypeMixin, fixtures.TablesTest):
+class _DateTimeRangeTests(object):
 
     _col_type = TSRANGE
     _col_str = 'TSRANGE'
@@ -2594,7 +2617,7 @@ class DateTimeRangeTests(_RangeTypeMixin, fixtures.TablesTest):
         )
 
 
-class DateTimeTZRangeTests(_RangeTypeMixin, fixtures.TablesTest):
+class _DateTimeTZRangeTests(object):
 
     _col_type = TSTZRANGE
     _col_str = 'TSTZRANGE'
@@ -2620,6 +2643,54 @@ class DateTimeTZRangeTests(_RangeTypeMixin, fixtures.TablesTest):
         return self.extras().DateTimeTZRange(*self.tstzs())
 
 
+class Int4RangeCompilationTest(_Int4RangeTests, _RangeTypeCompilation):
+    pass
+
+
+class Int4RangeRoundTripTest(_Int4RangeTests, _RangeTypeRoundTrip):
+    pass
+
+
+class Int8RangeCompilationTest(_Int8RangeTests, _RangeTypeCompilation):
+    pass
+
+
+class Int8RangeRoundTripTest(_Int8RangeTests, _RangeTypeRoundTrip):
+    pass
+
+
+class NumRangeCompilationTest(_NumRangeTests, _RangeTypeCompilation):
+    pass
+
+
+class NumRangeRoundTripTest(_NumRangeTests, _RangeTypeRoundTrip):
+    pass
+
+
+class DateRangeCompilationTest(_DateRangeTests, _RangeTypeCompilation):
+    pass
+
+
+class DateRangeRoundTripTest(_DateRangeTests, _RangeTypeRoundTrip):
+    pass
+
+
+class DateTimeRangeCompilationTest(_DateTimeRangeTests, _RangeTypeCompilation):
+    pass
+
+
+class DateTimeRangeRoundTripTest(_DateTimeRangeTests, _RangeTypeRoundTrip):
+    pass
+
+
+class DateTimeTZRangeCompilationTest(_DateTimeTZRangeTests, _RangeTypeCompilation):
+    pass
+
+
+class DateTimeTZRangeRoundTripTest(_DateTimeTZRangeTests, _RangeTypeRoundTrip):
+    pass
+
+
 class JSONTest(AssertsCompiledSQL, fixtures.TestBase):
     __dialect__ = 'postgresql'