]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Merge pull request #5 from cjw296/pg-ranges
authormike bayer <mike_mp@zzzcomputing.com>
Sat, 22 Jun 2013 14:47:02 +0000 (07:47 -0700)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 22 Jun 2013 15:27:09 +0000 (11:27 -0400)
Support for Postgres range types.

doc/build/dialects/postgresql.rst
lib/sqlalchemy/dialects/postgresql/__init__.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/constraints.py [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/ranges.py [new file with mode: 0644]
test/dialect/test_postgresql.py
test/requirements.py

index df141cce00cc7acc7547dde628a7641c93f503d2..f2b401e9a57176695a8640fd71e86080fab1cd24 100644 (file)
@@ -16,7 +16,8 @@ they originate from :mod:`sqlalchemy.types` or from the local dialect::
         ARRAY, BIGINT, BIT, BOOLEAN, BYTEA, CHAR, CIDR, DATE, \
         DOUBLE_PRECISION, ENUM, FLOAT, HSTORE, INET, INTEGER, \
         INTERVAL, MACADDR, NUMERIC, REAL, SMALLINT, TEXT, TIME, \
-        TIMESTAMP, UUID, VARCHAR
+        TIMESTAMP, UUID, VARCHAR, INT4RANGE, INT8RANGE, NUMRANGE, \
+        DATERANGE, TSRANGE, TSTZRANGE
 
 Types which are specific to PostgreSQL, or have PostgreSQL-specific
 construction arguments, are as follows:
@@ -81,6 +82,54 @@ construction arguments, are as follows:
     :members: __init__
     :show-inheritance:
 
+.. autoclass:: sqlalchemy.dialects.postgresql.ranges.RangeOperators
+    :members:
+
+.. autoclass:: INT4RANGE
+   :show-inheritance:
+
+.. autoclass:: INT8RANGE
+   :show-inheritance:
+
+.. autoclass:: NUMRANGE
+   :show-inheritance:
+
+.. autoclass:: DATERANGE
+   :show-inheritance:
+
+.. autoclass:: TSRANGE
+   :show-inheritance:
+
+.. autoclass:: TSTZRANGE
+   :show-inheritance:
+
+
+PostgreSQL Constraint Types
+---------------------------
+
+SQLAlchemy supports Postgresql EXCLUDE constraints via the
+:class:`ExcludeConstraint` class:
+
+.. autoclass:: ExcludeConstraint
+   :show-inheritance:
+   :members: __init__
+
+For example::
+
+  from sqlalchemy.dialects.postgresql import (
+      ExcludeConstraint,
+      TSRANGE as Range,
+      )
+
+  class RoomBookings(Base):
+
+      room = Column(Integer(), primary_key=True)
+      during = Column(TSRANGE())
+
+      __table_args__ = (
+          ExcludeConstraint(('room', '='), ('during', '&&')),
+      )
+
 psycopg2
 --------------
 
index d0f785bdda314bc9d4f39e07a301584056d6ab29..408b678467ded92201554db7dc0013f8f9b01e2e 100644 (file)
@@ -12,12 +12,16 @@ from .base import \
     INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, \
     INET, CIDR, UUID, BIT, MACADDR, DOUBLE_PRECISION, TIMESTAMP, TIME, \
     DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect, array, Any, All
+from .constraints import ExcludeConstraint
 from .hstore import HSTORE, hstore
+from .ranges import INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, \
+    TSTZRANGE
 
 __all__ = (
     'INTEGER', 'BIGINT', 'SMALLINT', 'VARCHAR', 'CHAR', 'TEXT', 'NUMERIC',
     'FLOAT', 'REAL', 'INET', 'CIDR', 'UUID', 'BIT', 'MACADDR',
     'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME', 'DATE', 'BYTEA', 'BOOLEAN',
     'INTERVAL', 'ARRAY', 'ENUM', 'dialect', 'Any', 'All', 'array', 'HSTORE',
-    'hstore'
+    'hstore', 'INT4RANGE', 'INT8RANGE', 'NUMRANGE', 'DATERANGE',
+    'TSRANGE', 'TSTZRANGE'
 )
index 8c9af5bcf6810259782492b897ad67e0126c6a6d..e1fc9e21116f50e97df2434edac8d1eaefdd07c3 100644 (file)
@@ -444,7 +444,7 @@ class array(expression.Tuple):
 
     An instance of :class:`.array` will always have the datatype
     :class:`.ARRAY`.  The "inner" type of the array is inferred from
-    the values present, unless the "type_" keyword argument is passed::
+    the values present, unless the ``type_`` keyword argument is passed::
 
         array(['foo', 'bar'], type_=CHAR)
 
@@ -1142,6 +1142,22 @@ class PGDDLCompiler(compiler.DDLCompiler):
             text += " WHERE " + where_compiled
         return text
 
+    def visit_exclude_constraint(self, constraint):
+        text = ""
+        if constraint.name is not None:
+            text += "CONSTRAINT %s " % \
+                    self.preparer.format_constraint(constraint)
+        elements = []
+        for c in constraint.columns:
+            op = constraint.operators[c.name]
+            elements.append(self.preparer.quote(c.name, c.quote)+' WITH '+op)
+        text += "EXCLUDE USING %s (%s)" % (constraint.using, ', '.join(elements))
+        if constraint.where is not None:
+            sqltext = sql_util.expression_as_ddl(constraint.where)
+            text += ' WHERE (%s)' % self.sql_compiler.process(sqltext)
+        text += self.define_constraint_deferrability(constraint)
+        return text
+
 
 class PGTypeCompiler(compiler.GenericTypeCompiler):
     def visit_INET(self, type_):
@@ -1168,6 +1184,24 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
     def visit_HSTORE(self, type_):
         return "HSTORE"
 
+    def visit_INT4RANGE(self, type_):
+        return "INT4RANGE"
+
+    def visit_INT8RANGE(self, type_):
+        return "INT8RANGE"
+
+    def visit_NUMRANGE(self, type_):
+        return "NUMRANGE"
+
+    def visit_DATERANGE(self, type_):
+        return "DATERANGE"
+
+    def visit_TSRANGE(self, type_):
+        return "TSRANGE"
+
+    def visit_TSTZRANGE(self, type_):
+        return "TSTZRANGE"
+
     def visit_datetime(self, type_):
         return self.visit_TIMESTAMP(type_)
 
diff --git a/lib/sqlalchemy/dialects/postgresql/constraints.py b/lib/sqlalchemy/dialects/postgresql/constraints.py
new file mode 100644 (file)
index 0000000..5b8bbe6
--- /dev/null
@@ -0,0 +1,73 @@
+# Copyright (C) 2013 the SQLAlchemy authors and contributors <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+from sqlalchemy.schema import ColumnCollectionConstraint
+from sqlalchemy.sql import expression
+
+class ExcludeConstraint(ColumnCollectionConstraint):
+    """A table-level EXCLUDE constraint.
+
+    Defines an EXCLUDE constraint as described in the `postgres
+    documentation`__.
+
+    __ http://www.postgresql.org/docs/9.0/static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE
+    """
+
+    __visit_name__ = 'exclude_constraint'
+
+    where = None
+
+    def __init__(self, *elements, **kw):
+        """
+        :param \*elements:
+          A sequence of two tuples of the form ``(column, operator)`` where
+          column must be a column name or Column object and operator must
+          be a string containing the operator to use.
+
+        :param name:
+          Optional, the in-database name of this constraint.
+
+        :param deferrable:
+          Optional bool.  If set, emit DEFERRABLE or NOT DEFERRABLE when
+          issuing DDL for this constraint.
+
+        :param initially:
+          Optional string.  If set, emit INITIALLY <value> when issuing DDL
+          for this constraint.
+
+        :param using:
+          Optional string.  If set, emit USING <index_method> when issuing DDL
+          for this constraint. Defaults to 'gist'.
+          
+        :param where:
+          Optional string.  If set, emit WHERE <predicate> when issuing DDL
+          for this constraint.
+
+        """
+        ColumnCollectionConstraint.__init__(
+            self,
+            *[col for col, op in elements],
+            name=kw.get('name'),
+            deferrable=kw.get('deferrable'),
+            initially=kw.get('initially')
+            )
+        self.operators = {}
+        for col_or_string, op in elements:
+            name = getattr(col_or_string, 'name', col_or_string)
+            self.operators[name] = op
+        self.using = kw.get('using', 'gist')
+        where = kw.get('where')
+        if where:
+            self.where =  expression._literal_as_text(where)
+            
+    def copy(self, **kw):
+        elements = [(col, self.operators[col])
+                    for col in self.columns.keys()]
+        c = self.__class__(*elements,
+                            name=self.name,
+                            deferrable=self.deferrable,
+                            initially=self.initially)
+        c.dispatch._update(self.dispatch)
+        return c
+
diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py
new file mode 100644 (file)
index 0000000..e7ab1d5
--- /dev/null
@@ -0,0 +1,133 @@
+# Copyright (C) 2013 the SQLAlchemy authors and contributors <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+from .base import ischema_names
+from ... import types as sqltypes
+
+__all__ = ('INT4RANGE', 'INT8RANGE', 'NUMRANGE')
+
+class RangeOperators(object):
+    """
+    This mixin provides functionality for the Range Operators
+    listed in Table 9-44 of the `postgres documentation`__ for Range
+    Functions and Operators. It is used by all the range types
+    provided in the ``postgres`` dialect and can likely be used for
+    any range types you create yourself.
+
+    __ http://www.postgresql.org/docs/devel/static/functions-range.html
+
+    No extra support is provided for the Range Functions listed in
+    Table 9-45 of the postgres documentation. For these, the normal
+    :func:`~sqlalchemy.sql.expression.func` object should be used.
+    """
+
+    class comparator_factory(sqltypes.Concatenable.Comparator):
+        """Define comparison operations for range types."""
+
+        def __ne__(self, other):
+            "Boolean expression. Returns true if two ranges are not equal" 
+            return self.expr.op('<>')(other)
+
+        def contains(self, other, **kw):
+            """Boolean expression. Returns true if the right hand operand,
+            which can be an element or a range, is contained within the
+            column.
+            """
+            return self.expr.op('@>')(other)
+
+        def contained_by(self, other):
+            """Boolean expression. Returns true if the column is contained
+            within the right hand operand.
+            """
+            return self.expr.op('<@')(other)
+
+        def overlaps(self, other):
+            """Boolean expression. Returns true if the column overlaps
+            (has points in common with) the right hand operand.
+            """
+            return self.expr.op('&&')(other)
+
+        def strictly_left_of(self, other):
+            """Boolean expression. Returns true if the column is strictly
+            left of the right hand operand.
+            """
+            return self.expr.op('<<')(other)
+            
+        __lshift__ = strictly_left_of
+
+        def strictly_right_of(self, other):
+            """Boolean expression. Returns true if the column is strictly
+            right of the right hand operand.
+            """
+            return self.expr.op('>>')(other)
+            
+        __rshift__ = strictly_right_of
+        
+        def not_extend_right_of(self, other):
+            """Boolean expression. Returns true if the range in the column
+            does not extend right of the range in the operand.
+            """
+            return self.expr.op('&<')(other)
+        
+        def not_extend_left_of(self, other):
+            """Boolean expression. Returns true if the range in the column
+            does not extend left of the range in the operand.
+            """
+            return self.expr.op('&>')(other)
+
+        def adjacent_to(self, other):
+            """Boolean expression. Returns true if the range in the column
+            is adjacent to the range in the operand.
+            """
+            return self.expr.op('-|-')(other)
+
+        def __add__(self, other):
+            """Range expression. Returns the union of the two ranges.
+            Will raise an exception if the resulting range is not
+            contigous.
+            """
+            return self.expr.op('+')(other)
+
+class INT4RANGE(RangeOperators, sqltypes.TypeEngine):
+    "Represent the Postgresql INT4RANGE type."
+    
+    __visit_name__ = 'INT4RANGE'
+
+ischema_names['int4range'] = INT4RANGE
+
+class INT8RANGE(RangeOperators, sqltypes.TypeEngine):
+    "Represent the Postgresql INT8RANGE type."
+    
+    __visit_name__ = 'INT8RANGE'
+
+ischema_names['int8range'] = INT8RANGE
+
+class NUMRANGE(RangeOperators, sqltypes.TypeEngine):
+    "Represent the Postgresql NUMRANGE type."
+    
+    __visit_name__ = 'NUMRANGE'
+
+ischema_names['numrange'] = NUMRANGE
+
+class DATERANGE(RangeOperators, sqltypes.TypeEngine):
+    "Represent the Postgresql DATERANGE type."
+    
+    __visit_name__ = 'DATERANGE'
+
+ischema_names['daterange'] = DATERANGE
+
+class TSRANGE(RangeOperators, sqltypes.TypeEngine):
+    "Represent the Postgresql TSRANGE type."
+    
+    __visit_name__ = 'TSRANGE'
+
+ischema_names['tsrange'] = TSRANGE
+
+class TSTZRANGE(RangeOperators, sqltypes.TypeEngine):
+    "Represent the Postgresql TSTZRANGE type."
+    
+    __visit_name__ = 'TSTZRANGE'
+
+ischema_names['tstzrange'] = TSTZRANGE
index 43c4719190bc3e3587e53ed8d7758a7d45aeae49..1389fe5f877cd40abc59739e169a538a3e8da0f7 100644 (file)
@@ -17,7 +17,9 @@ from sqlalchemy import Table, Column, select, MetaData, text, Integer, \
 from sqlalchemy.orm import Session, mapper, aliased
 from sqlalchemy import exc, schema, types
 from sqlalchemy.dialects.postgresql import base as postgresql
-from sqlalchemy.dialects.postgresql import HSTORE, hstore, array
+from sqlalchemy.dialects.postgresql import HSTORE, hstore, array, \
+            INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, TSTZRANGE, \
+            ExcludeConstraint
 import decimal
 from sqlalchemy import util
 from sqlalchemy.testing.util import round_decimal
@@ -182,6 +184,53 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
                             'USING hash (data)',
                             dialect=postgresql.dialect())
 
+    def test_exclude_constraint_min(self):
+        m = MetaData()
+        tbl = Table('testtbl', m, 
+                    Column('room', Integer, primary_key=True))
+        cons = ExcludeConstraint(('room', '='))
+        tbl.append_constraint(cons)
+        self.assert_compile(schema.AddConstraint(cons),
+                            'ALTER TABLE testtbl ADD EXCLUDE USING gist '
+                            '(room WITH =)',
+                            dialect=postgresql.dialect())
+
+    def test_exclude_constraint_full(self):
+        m = MetaData()
+        room = Column('room', Integer, primary_key=True)
+        tbl = Table('testtbl', m,
+                    room,
+                    Column('during', TSRANGE))
+        room = Column('room', Integer, primary_key=True)
+        cons = ExcludeConstraint((room, '='), ('during', '&&'),
+                                 name='my_name',
+                                 using='gist',
+                                 where="room > 100",
+                                 deferrable=True,
+                                 initially='immediate')
+        tbl.append_constraint(cons)
+        self.assert_compile(schema.AddConstraint(cons),
+                            'ALTER TABLE testtbl ADD CONSTRAINT my_name '
+                            'EXCLUDE USING gist '
+                            '(room WITH =, during WITH ''&&) WHERE '
+                            '(room > 100) DEFERRABLE INITIALLY immediate',
+                            dialect=postgresql.dialect())
+
+    def test_exclude_constraint_copy(self):
+        m = MetaData()
+        cons = ExcludeConstraint(('room', '='))
+        tbl = Table('testtbl', m, 
+              Column('room', Integer, primary_key=True),
+              cons)
+        # apparently you can't copy a ColumnCollectionConstraint until
+        # after it has been bound to a table...
+        cons_copy = cons.copy()
+        tbl.append_constraint(cons_copy)
+        self.assert_compile(schema.AddConstraint(cons_copy),
+                            'ALTER TABLE testtbl ADD EXCLUDE USING gist '
+                            '(room WITH =)',
+                            dialect=postgresql.dialect())
+
     def test_substring(self):
         self.assert_compile(func.substring('abc', 1, 2),
                             'SUBSTRING(%(substring_1)s FROM %(substring_2)s '
@@ -3241,3 +3290,282 @@ class HStoreRoundTripTest(fixtures.TablesTest):
     def test_unicode_round_trip_native(self):
         engine = testing.db
         self._test_unicode_round_trip(engine)
+
+class _RangeTypeMixin(object):
+    __requires__ = 'range_types',
+    __dialect__ = 'postgresql+psycopg2'
+
+    @property
+    def extras(self):
+        # done this way so we don't get ImportErrors with
+        # older psycopg2 versions.
+        from psycopg2 import extras
+        return extras
+    
+    @classmethod
+    def define_tables(cls, metadata):
+        # no reason ranges shouldn't be primary keys,
+        # so lets just use them as such
+        table = Table('data_table', metadata,
+            Column('range', cls._col_type, primary_key=True),
+        )
+        cls.col = table.c.range
+
+    def test_actual_type(self):
+        eq_(str(self._col_type()), self._col_str)
+        
+    def test_reflect(self):
+        from sqlalchemy import inspect
+        insp = inspect(testing.db)
+        cols = insp.get_columns('data_table')
+        assert isinstance(cols[0]['type'], self._col_type)
+
+    def _assert_data(self):
+        data = testing.db.execute(
+            select([self.tables.data_table.c.range])
+        ).fetchall()
+        eq_(data, [(self._data_obj(), )])
+
+    def test_insert_obj(self):
+        testing.db.engine.execute(
+            self.tables.data_table.insert(),
+            {'range': self._data_obj()}
+        )
+        self._assert_data()
+
+    def test_insert_text(self):
+        testing.db.engine.execute(
+            self.tables.data_table.insert(),
+            {'range': self._data_str}
+        )
+        self._assert_data()
+
+    # operator tests
+        
+    def _test_clause(self, colclause, expected):
+        dialect = postgresql.dialect()
+        compiled = str(colclause.compile(dialect=dialect))
+        eq_(compiled, expected)
+
+    def test_where_equal(self):
+        self._test_clause(
+            self.col==self._data_str,
+            "data_table.range = %(range_1)s"
+        )
+
+    def test_where_not_equal(self):
+        self._test_clause(
+            self.col!=self._data_str,
+            "data_table.range <> %(range_1)s"
+        )
+
+    def test_where_less_than(self):
+        self._test_clause(
+            self.col < self._data_str,
+            "data_table.range < %(range_1)s"
+        )
+
+    def test_where_greater_than(self):
+        self._test_clause(
+            self.col > self._data_str,
+            "data_table.range > %(range_1)s"
+        )
+
+    def test_where_less_than_or_equal(self):
+        self._test_clause(
+            self.col <= self._data_str,
+            "data_table.range <= %(range_1)s"
+        )
+
+    def test_where_greater_than_or_equal(self):
+        self._test_clause(
+            self.col >= self._data_str,
+            "data_table.range >= %(range_1)s"
+        )
+
+    def test_contains(self):
+        self._test_clause(
+            self.col.contains(self._data_str),
+            "data_table.range @> %(range_1)s"
+        )
+
+    def test_contained_by(self):
+        self._test_clause(
+            self.col.contained_by(self._data_str),
+            "data_table.range <@ %(range_1)s"
+        )
+
+    def test_overlaps(self):
+        self._test_clause(
+            self.col.overlaps(self._data_str),
+            "data_table.range && %(range_1)s"
+        )
+
+    def test_strictly_left_of(self):
+        self._test_clause(
+            self.col << self._data_str,
+            "data_table.range << %(range_1)s"
+        )
+        self._test_clause(
+            self.col.strictly_left_of(self._data_str),
+            "data_table.range << %(range_1)s"
+        )
+
+    def test_strictly_right_of(self):
+        self._test_clause(
+            self.col >> self._data_str,
+            "data_table.range >> %(range_1)s"
+        )
+        self._test_clause(
+            self.col.strictly_right_of(self._data_str),
+            "data_table.range >> %(range_1)s"
+        )
+
+    def test_not_extend_right_of(self):
+        self._test_clause(
+            self.col.not_extend_right_of(self._data_str),
+            "data_table.range &< %(range_1)s"
+        )
+
+    def test_not_extend_left_of(self):
+        self._test_clause(
+            self.col.not_extend_left_of(self._data_str),
+            "data_table.range &> %(range_1)s"
+        )
+
+    def test_adjacent_to(self):
+        self._test_clause(
+            self.col.adjacent_to(self._data_str),
+            "data_table.range -|- %(range_1)s"
+        )
+
+    def test_union(self):
+        self._test_clause(
+            self.col + self.col,
+            "data_table.range + data_table.range"
+        )
+
+    def test_union_result(self):
+        # insert
+        testing.db.engine.execute(
+            self.tables.data_table.insert(),
+            {'range': self._data_str}
+        )
+        # select
+        range = self.tables.data_table.c.range
+        data = testing.db.execute(
+            select([range + range])
+            ).fetchall()
+        eq_(data, [(self._data_obj(), )])
+        
+
+    def test_intersection(self):
+        self._test_clause(
+            self.col * self.col,
+            "data_table.range * data_table.range"
+        )
+
+    def test_intersection_result(self):
+        # insert
+        testing.db.engine.execute(
+            self.tables.data_table.insert(),
+            {'range': self._data_str}
+        )
+        # select
+        range = self.tables.data_table.c.range
+        data = testing.db.execute(
+            select([range * range])
+            ).fetchall()
+        eq_(data, [(self._data_obj(), )])
+        
+    def test_different(self):
+        self._test_clause(
+            self.col - self.col,
+            "data_table.range - data_table.range"
+        )
+
+    def test_difference_result(self):
+        # insert
+        testing.db.engine.execute(
+            self.tables.data_table.insert(),
+            {'range': self._data_str}
+        )
+        # select
+        range = self.tables.data_table.c.range
+        data = testing.db.execute(
+            select([range - range])
+            ).fetchall()
+        eq_(data, [(self._data_obj().__class__(empty=True), )])
+        
+class Int4RangeTests(_RangeTypeMixin, fixtures.TablesTest):
+
+    _col_type = INT4RANGE
+    _col_str = 'INT4RANGE'
+    _data_str = '[1,2)'
+    def _data_obj(self):
+        return self.extras.NumericRange(1, 2)
+
+class Int8RangeTests(_RangeTypeMixin, fixtures.TablesTest):
+
+    _col_type = INT8RANGE
+    _col_str = 'INT8RANGE'
+    _data_str = '[9223372036854775806,9223372036854775807)'
+    def _data_obj(self):
+        return self.extras.NumericRange(
+            9223372036854775806, 9223372036854775807
+            )
+
+class NumRangeTests(_RangeTypeMixin, fixtures.TablesTest):
+
+    _col_type = NUMRANGE
+    _col_str = 'NUMRANGE'
+    _data_str = '[1.0,2.0)'
+    def _data_obj(self):
+        return self.extras.NumericRange(
+            decimal.Decimal('1.0'), decimal.Decimal('2.0')
+            )
+
+class DateRangeTests(_RangeTypeMixin, fixtures.TablesTest):
+
+    _col_type = DATERANGE
+    _col_str = 'DATERANGE'
+    _data_str = '[2013-03-23,2013-03-24)'
+    def _data_obj(self):
+        return self.extras.DateRange(
+            datetime.date(2013, 3, 23), datetime.date(2013, 3, 24)
+            )
+
+class DateTimeRangeTests(_RangeTypeMixin, fixtures.TablesTest):
+
+    _col_type = TSRANGE
+    _col_str = 'TSRANGE'
+    _data_str = '[2013-03-23 14:30,2013-03-23 23:30)'
+    def _data_obj(self):
+        return self.extras.DateTimeRange(
+            datetime.datetime(2013, 3, 23, 14, 30),
+            datetime.datetime(2013, 3, 23, 23, 30)
+            )
+
+class DateTimeTZRangeTests(_RangeTypeMixin, fixtures.TablesTest):
+
+    _col_type = TSTZRANGE
+    _col_str = 'TSTZRANGE'
+
+    # make sure we use one, steady timestamp with timezone pair
+    # for all parts of all these tests
+    _tstzs = None
+    def tstzs(self):
+        if self._tstzs is None:
+            lower = testing.db.connect().scalar(
+                func.current_timestamp().select()
+                )
+            upper = lower+datetime.timedelta(1)
+            self._tstzs = (lower, upper)
+        return self._tstzs
+
+    @property
+    def _data_str(self):
+        return '[%s,%s)' % self.tstzs()
+    
+    def _data_obj(self):
+        return self.extras.DateTimeTZRange(*self.tstzs())
index 42128669d8d69dcb38cb718115217af77893f87e..807f21baca58ebd3f1e8f5d8688e93cac89b3ea5 100644 (file)
@@ -585,6 +585,21 @@ class DefaultRequirements(SuiteRequirements):
 
         return only_if(check_hstore)
 
+    @property
+    def range_types(self):
+        def check_range_types():
+            if not against("postgresql+psycopg2"):
+                return False
+            try:
+                self.db.execute("select '[1,2)'::int4range;")
+                # only supported in psycopg 2.5+
+                from psycopg2.extras import NumericRange
+                return True
+            except:
+                return False
+
+        return only_if(check_range_types)
+
     @property
     def sqlite(self):
         return skip_if(lambda: not self._has_sqlite())