From f4020282b798ea510e6aafda779ab33c692c0120 Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Sun, 19 May 2013 15:20:57 +0100 Subject: [PATCH] add support for range operators listed in http://www.postgresql.org/docs/9.2/interactive/functions-range.html --- lib/sqlalchemy/dialects/postgresql/ranges.py | 81 +++++++++- test/dialect/test_postgresql.py | 160 ++++++++++++++++++- 2 files changed, 234 insertions(+), 7 deletions(-) diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index b3a670d919..2054ef1374 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -8,42 +8,111 @@ from ... import types as sqltypes __all__ = ('INT4RANGE', 'INT8RANGE', 'NUMRANGE') -class INT4RANGE(sqltypes.TypeEngine): +class RangeOperators(object): + + class comparator_factory(sqltypes.Concatenable.Comparator): + """Define comparison operations for range types.""" + + def __ne__(self, other): + "Boolean expression. Returns true if two ranges are not equal" + return self.expr.op('<>')(other) + + def contains(self, other, **kw): + """Boolean expression. Returns true if the right hand operand, + which can be an element or a range, is contained within the + column. + """ + return self.expr.op('@>')(other) + + def contained_by(self, other): + """Boolean expression. Returns true if the column is contained + within the right hand operand. + """ + return self.expr.op('<@')(other) + + def overlaps(self, other): + """Boolean expression. Returns true if the column overlaps + (has points in common with) the right hand operand. + """ + return self.expr.op('&&')(other) + + def strictly_left_of(self, other): + """Boolean expression. Returns true if the column is strictly + left of the right hand operand. + """ + return self.expr.op('<<')(other) + + __lshift__ = strictly_left_of + + def strictly_right_of(self, other): + """Boolean expression. Returns true if the column is strictly + right of the right hand operand. + """ + return self.expr.op('>>')(other) + + __rshift__ = strictly_right_of + + def not_extend_right_of(self, other): + """Boolean expression. Returns true if the range in the column + does not extend right of the range in the operand. + """ + return self.expr.op('&<')(other) + + def not_extend_left_of(self, other): + """Boolean expression. Returns true if the range in the column + does not extend left of the range in the operand. + """ + return self.expr.op('&>')(other) + + def adjacent_to(self, other): + """Boolean expression. Returns true if the range in the column + is adjacent to the range in the operand. + """ + return self.expr.op('-|-')(other) + + def __add__(self, other): + """Range expression. Returns the union of the two ranges. + Will raise an exception if the resulting range is not + contigous. + """ + return self.expr.op('+')(other) + +class INT4RANGE(RangeOperators, sqltypes.TypeEngine): "Represent the Postgresql INT4RANGE type." __visit_name__ = 'INT4RANGE' ischema_names['int4range'] = INT4RANGE -class INT8RANGE(sqltypes.TypeEngine): +class INT8RANGE(RangeOperators, sqltypes.TypeEngine): "Represent the Postgresql INT8RANGE type." __visit_name__ = 'INT8RANGE' ischema_names['int8range'] = INT8RANGE -class NUMRANGE(sqltypes.TypeEngine): +class NUMRANGE(RangeOperators, sqltypes.TypeEngine): "Represent the Postgresql NUMRANGE type." __visit_name__ = 'NUMRANGE' ischema_names['numrange'] = NUMRANGE -class DATERANGE(sqltypes.TypeEngine): +class DATERANGE(RangeOperators, sqltypes.TypeEngine): "Represent the Postgresql DATERANGE type." __visit_name__ = 'DATERANGE' ischema_names['daterange'] = DATERANGE -class TSRANGE(sqltypes.TypeEngine): +class TSRANGE(RangeOperators, sqltypes.TypeEngine): "Represent the Postgresql TSRANGE type." __visit_name__ = 'TSRANGE' ischema_names['tsrange'] = TSRANGE -class TSTZRANGE(sqltypes.TypeEngine): +class TSTZRANGE(RangeOperators, sqltypes.TypeEngine): "Represent the Postgresql TSTZRANGE type." __visit_name__ = 'TSTZRANGE' diff --git a/test/dialect/test_postgresql.py b/test/dialect/test_postgresql.py index de37ffd7dc..70b683b088 100644 --- a/test/dialect/test_postgresql.py +++ b/test/dialect/test_postgresql.py @@ -3249,9 +3249,10 @@ class _RangeTypeMixin(object): def define_tables(cls, metadata): # no reason ranges shouldn't be primary keys, # so lets just use them as such - Table('data_table', metadata, + table = Table('data_table', metadata, Column('range', cls._col_type, primary_key=True), ) + cls.col = table.c.range def test_actual_type(self): eq_(str(self._col_type()), self._col_str) @@ -3282,6 +3283,163 @@ class _RangeTypeMixin(object): ) self._assert_data() + # operator tests + + def _test_clause(self, colclause, expected): + dialect = postgresql.dialect() + compiled = str(colclause.compile(dialect=dialect)) + eq_(compiled, expected) + + def test_where_equal(self): + self._test_clause( + self.col==self._data_str, + "data_table.range = %(range_1)s" + ) + + def test_where_not_equal(self): + self._test_clause( + self.col!=self._data_str, + "data_table.range <> %(range_1)s" + ) + + def test_where_less_than(self): + self._test_clause( + self.col < self._data_str, + "data_table.range < %(range_1)s" + ) + + def test_where_greater_than(self): + self._test_clause( + self.col > self._data_str, + "data_table.range > %(range_1)s" + ) + + def test_where_less_than_or_equal(self): + self._test_clause( + self.col <= self._data_str, + "data_table.range <= %(range_1)s" + ) + + def test_where_greater_than_or_equal(self): + self._test_clause( + self.col >= self._data_str, + "data_table.range >= %(range_1)s" + ) + + def test_contains(self): + self._test_clause( + self.col.contains(self._data_str), + "data_table.range @> %(range_1)s" + ) + + def test_contained_by(self): + self._test_clause( + self.col.contained_by(self._data_str), + "data_table.range <@ %(range_1)s" + ) + + def test_overlaps(self): + self._test_clause( + self.col.overlaps(self._data_str), + "data_table.range && %(range_1)s" + ) + + def test_strictly_left_of(self): + self._test_clause( + self.col << self._data_str, + "data_table.range << %(range_1)s" + ) + self._test_clause( + self.col.strictly_left_of(self._data_str), + "data_table.range << %(range_1)s" + ) + + def test_strictly_right_of(self): + self._test_clause( + self.col >> self._data_str, + "data_table.range >> %(range_1)s" + ) + self._test_clause( + self.col.strictly_right_of(self._data_str), + "data_table.range >> %(range_1)s" + ) + + def test_not_extend_right_of(self): + self._test_clause( + self.col.not_extend_right_of(self._data_str), + "data_table.range &< %(range_1)s" + ) + + def test_not_extend_left_of(self): + self._test_clause( + self.col.not_extend_left_of(self._data_str), + "data_table.range &> %(range_1)s" + ) + + def test_adjacent_to(self): + self._test_clause( + self.col.adjacent_to(self._data_str), + "data_table.range -|- %(range_1)s" + ) + + def test_union(self): + self._test_clause( + self.col + self.col, + "data_table.range + data_table.range" + ) + + def test_union_result(self): + # insert + testing.db.engine.execute( + self.tables.data_table.insert(), + {'range': self._data_str} + ) + # select + range = self.tables.data_table.c.range + data = testing.db.execute( + select([range + range]) + ).fetchall() + eq_(data, [(self._data_obj(), )]) + + + def test_intersection(self): + self._test_clause( + self.col * self.col, + "data_table.range * data_table.range" + ) + + def test_intersection_result(self): + # insert + testing.db.engine.execute( + self.tables.data_table.insert(), + {'range': self._data_str} + ) + # select + range = self.tables.data_table.c.range + data = testing.db.execute( + select([range * range]) + ).fetchall() + eq_(data, [(self._data_obj(), )]) + + def test_different(self): + self._test_clause( + self.col - self.col, + "data_table.range - data_table.range" + ) + + def test_difference_result(self): + # insert + testing.db.engine.execute( + self.tables.data_table.insert(), + {'range': self._data_str} + ) + # select + range = self.tables.data_table.c.range + data = testing.db.execute( + select([range - range]) + ).fetchall() + eq_(data, [(self._data_obj().__class__(empty=True), )]) + class Int4RangeTests(_RangeTypeMixin, fixtures.TablesTest): _col_type = INT4RANGE -- 2.47.3