]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add support for range operators listed in http://www.postgresql.org/docs/9.2/interact...
authorChris Withers <chris@simplistix.co.uk>
Sun, 19 May 2013 14:20:57 +0000 (15:20 +0100)
committerChris Withers <chris@simplistix.co.uk>
Mon, 10 Jun 2013 11:09:55 +0000 (12:09 +0100)
lib/sqlalchemy/dialects/postgresql/ranges.py
test/dialect/test_postgresql.py

index b3a670d919c263afa1e77504146eef719a578cf1..2054ef137429624ab44ad50fb564978b82ba23e0 100644 (file)
@@ -8,42 +8,111 @@ from ... import types as sqltypes
 
 __all__ = ('INT4RANGE', 'INT8RANGE', 'NUMRANGE')
 
-class INT4RANGE(sqltypes.TypeEngine):
+class RangeOperators(object):
+
+    class comparator_factory(sqltypes.Concatenable.Comparator):
+        """Define comparison operations for range types."""
+
+        def __ne__(self, other):
+            "Boolean expression. Returns true if two ranges are not equal" 
+            return self.expr.op('<>')(other)
+
+        def contains(self, other, **kw):
+            """Boolean expression. Returns true if the right hand operand,
+            which can be an element or a range, is contained within the
+            column.
+            """
+            return self.expr.op('@>')(other)
+
+        def contained_by(self, other):
+            """Boolean expression. Returns true if the column is contained
+            within the right hand operand.
+            """
+            return self.expr.op('<@')(other)
+
+        def overlaps(self, other):
+            """Boolean expression. Returns true if the column overlaps
+            (has points in common with) the right hand operand.
+            """
+            return self.expr.op('&&')(other)
+
+        def strictly_left_of(self, other):
+            """Boolean expression. Returns true if the column is strictly
+            left of the right hand operand.
+            """
+            return self.expr.op('<<')(other)
+            
+        __lshift__ = strictly_left_of
+
+        def strictly_right_of(self, other):
+            """Boolean expression. Returns true if the column is strictly
+            right of the right hand operand.
+            """
+            return self.expr.op('>>')(other)
+            
+        __rshift__ = strictly_right_of
+        
+        def not_extend_right_of(self, other):
+            """Boolean expression. Returns true if the range in the column
+            does not extend right of the range in the operand.
+            """
+            return self.expr.op('&<')(other)
+        
+        def not_extend_left_of(self, other):
+            """Boolean expression. Returns true if the range in the column
+            does not extend left of the range in the operand.
+            """
+            return self.expr.op('&>')(other)
+
+        def adjacent_to(self, other):
+            """Boolean expression. Returns true if the range in the column
+            is adjacent to the range in the operand.
+            """
+            return self.expr.op('-|-')(other)
+
+        def __add__(self, other):
+            """Range expression. Returns the union of the two ranges.
+            Will raise an exception if the resulting range is not
+            contigous.
+            """
+            return self.expr.op('+')(other)
+
+class INT4RANGE(RangeOperators, sqltypes.TypeEngine):
     "Represent the Postgresql INT4RANGE type."
     
     __visit_name__ = 'INT4RANGE'
 
 ischema_names['int4range'] = INT4RANGE
 
-class INT8RANGE(sqltypes.TypeEngine):
+class INT8RANGE(RangeOperators, sqltypes.TypeEngine):
     "Represent the Postgresql INT8RANGE type."
     
     __visit_name__ = 'INT8RANGE'
 
 ischema_names['int8range'] = INT8RANGE
 
-class NUMRANGE(sqltypes.TypeEngine):
+class NUMRANGE(RangeOperators, sqltypes.TypeEngine):
     "Represent the Postgresql NUMRANGE type."
     
     __visit_name__ = 'NUMRANGE'
 
 ischema_names['numrange'] = NUMRANGE
 
-class DATERANGE(sqltypes.TypeEngine):
+class DATERANGE(RangeOperators, sqltypes.TypeEngine):
     "Represent the Postgresql DATERANGE type."
     
     __visit_name__ = 'DATERANGE'
 
 ischema_names['daterange'] = DATERANGE
 
-class TSRANGE(sqltypes.TypeEngine):
+class TSRANGE(RangeOperators, sqltypes.TypeEngine):
     "Represent the Postgresql TSRANGE type."
     
     __visit_name__ = 'TSRANGE'
 
 ischema_names['tsrange'] = TSRANGE
 
-class TSTZRANGE(sqltypes.TypeEngine):
+class TSTZRANGE(RangeOperators, sqltypes.TypeEngine):
     "Represent the Postgresql TSTZRANGE type."
     
     __visit_name__ = 'TSTZRANGE'
index de37ffd7dc63ae367dc58304ef0e70f941c03f69..70b683b088200c17bcb2539ee69ab8e9de80de3c 100644 (file)
@@ -3249,9 +3249,10 @@ class _RangeTypeMixin(object):
     def define_tables(cls, metadata):
         # no reason ranges shouldn't be primary keys,
         # so lets just use them as such
-        Table('data_table', metadata,
+        table = Table('data_table', metadata,
             Column('range', cls._col_type, primary_key=True),
         )
+        cls.col = table.c.range
 
     def test_actual_type(self):
         eq_(str(self._col_type()), self._col_str)
@@ -3282,6 +3283,163 @@ class _RangeTypeMixin(object):
         )
         self._assert_data()
 
+    # operator tests
+        
+    def _test_clause(self, colclause, expected):
+        dialect = postgresql.dialect()
+        compiled = str(colclause.compile(dialect=dialect))
+        eq_(compiled, expected)
+
+    def test_where_equal(self):
+        self._test_clause(
+            self.col==self._data_str,
+            "data_table.range = %(range_1)s"
+        )
+
+    def test_where_not_equal(self):
+        self._test_clause(
+            self.col!=self._data_str,
+            "data_table.range <> %(range_1)s"
+        )
+
+    def test_where_less_than(self):
+        self._test_clause(
+            self.col < self._data_str,
+            "data_table.range < %(range_1)s"
+        )
+
+    def test_where_greater_than(self):
+        self._test_clause(
+            self.col > self._data_str,
+            "data_table.range > %(range_1)s"
+        )
+
+    def test_where_less_than_or_equal(self):
+        self._test_clause(
+            self.col <= self._data_str,
+            "data_table.range <= %(range_1)s"
+        )
+
+    def test_where_greater_than_or_equal(self):
+        self._test_clause(
+            self.col >= self._data_str,
+            "data_table.range >= %(range_1)s"
+        )
+
+    def test_contains(self):
+        self._test_clause(
+            self.col.contains(self._data_str),
+            "data_table.range @> %(range_1)s"
+        )
+
+    def test_contained_by(self):
+        self._test_clause(
+            self.col.contained_by(self._data_str),
+            "data_table.range <@ %(range_1)s"
+        )
+
+    def test_overlaps(self):
+        self._test_clause(
+            self.col.overlaps(self._data_str),
+            "data_table.range && %(range_1)s"
+        )
+
+    def test_strictly_left_of(self):
+        self._test_clause(
+            self.col << self._data_str,
+            "data_table.range << %(range_1)s"
+        )
+        self._test_clause(
+            self.col.strictly_left_of(self._data_str),
+            "data_table.range << %(range_1)s"
+        )
+
+    def test_strictly_right_of(self):
+        self._test_clause(
+            self.col >> self._data_str,
+            "data_table.range >> %(range_1)s"
+        )
+        self._test_clause(
+            self.col.strictly_right_of(self._data_str),
+            "data_table.range >> %(range_1)s"
+        )
+
+    def test_not_extend_right_of(self):
+        self._test_clause(
+            self.col.not_extend_right_of(self._data_str),
+            "data_table.range &< %(range_1)s"
+        )
+
+    def test_not_extend_left_of(self):
+        self._test_clause(
+            self.col.not_extend_left_of(self._data_str),
+            "data_table.range &> %(range_1)s"
+        )
+
+    def test_adjacent_to(self):
+        self._test_clause(
+            self.col.adjacent_to(self._data_str),
+            "data_table.range -|- %(range_1)s"
+        )
+
+    def test_union(self):
+        self._test_clause(
+            self.col + self.col,
+            "data_table.range + data_table.range"
+        )
+
+    def test_union_result(self):
+        # insert
+        testing.db.engine.execute(
+            self.tables.data_table.insert(),
+            {'range': self._data_str}
+        )
+        # select
+        range = self.tables.data_table.c.range
+        data = testing.db.execute(
+            select([range + range])
+            ).fetchall()
+        eq_(data, [(self._data_obj(), )])
+        
+
+    def test_intersection(self):
+        self._test_clause(
+            self.col * self.col,
+            "data_table.range * data_table.range"
+        )
+
+    def test_intersection_result(self):
+        # insert
+        testing.db.engine.execute(
+            self.tables.data_table.insert(),
+            {'range': self._data_str}
+        )
+        # select
+        range = self.tables.data_table.c.range
+        data = testing.db.execute(
+            select([range * range])
+            ).fetchall()
+        eq_(data, [(self._data_obj(), )])
+        
+    def test_different(self):
+        self._test_clause(
+            self.col - self.col,
+            "data_table.range - data_table.range"
+        )
+
+    def test_difference_result(self):
+        # insert
+        testing.db.engine.execute(
+            self.tables.data_table.insert(),
+            {'range': self._data_str}
+        )
+        # select
+        range = self.tables.data_table.c.range
+        data = testing.db.execute(
+            select([range - range])
+            ).fetchall()
+        eq_(data, [(self._data_obj().__class__(empty=True), )])
+        
 class Int4RangeTests(_RangeTypeMixin, fixtures.TablesTest):
 
     _col_type = INT4RANGE