]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Basic type support for the new range types in postgres 9.2
authorChris Withers <chris@simplistix.co.uk>
Sun, 19 May 2013 07:50:06 +0000 (08:50 +0100)
committerChris Withers <chris@simplistix.co.uk>
Mon, 10 Jun 2013 11:09:55 +0000 (12:09 +0100)
lib/sqlalchemy/dialects/postgresql/__init__.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/ranges.py [new file with mode: 0644]
test/dialect/test_postgresql.py
test/requirements.py

index d0f785bdda314bc9d4f39e07a301584056d6ab29..3c259671d94558c75d979889dbeeba449bedbf38 100644 (file)
@@ -13,11 +13,14 @@ from .base import \
     INET, CIDR, UUID, BIT, MACADDR, DOUBLE_PRECISION, TIMESTAMP, TIME, \
     DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect, array, Any, All
 from .hstore import HSTORE, hstore
+from .ranges import INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, \
+    TSTZRANGE
 
 __all__ = (
     'INTEGER', 'BIGINT', 'SMALLINT', 'VARCHAR', 'CHAR', 'TEXT', 'NUMERIC',
     'FLOAT', 'REAL', 'INET', 'CIDR', 'UUID', 'BIT', 'MACADDR',
     'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME', 'DATE', 'BYTEA', 'BOOLEAN',
     'INTERVAL', 'ARRAY', 'ENUM', 'dialect', 'Any', 'All', 'array', 'HSTORE',
-    'hstore'
+    'hstore', 'INT4RANGE', 'INT8RANGE', 'NUMRANGE', 'DATERANGE',
+    'TSRANGE', 'TSTZRANGE'
 )
index 0810e03849ef6bd41e1b36622bdb92d7d887fa7f..127e1130b193191dab176c7b4daac13710f4c380 100644 (file)
@@ -1150,6 +1150,24 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
     def visit_HSTORE(self, type_):
         return "HSTORE"
 
+    def visit_INT4RANGE(self, type_):
+        return "INT4RANGE"
+
+    def visit_INT8RANGE(self, type_):
+        return "INT8RANGE"
+
+    def visit_NUMRANGE(self, type_):
+        return "NUMRANGE"
+
+    def visit_DATERANGE(self, type_):
+        return "DATERANGE"
+
+    def visit_TSRANGE(self, type_):
+        return "TSRANGE"
+
+    def visit_TSTZRANGE(self, type_):
+        return "TSTZRANGE"
+
     def visit_datetime(self, type_):
         return self.visit_TIMESTAMP(type_)
 
diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py
new file mode 100644 (file)
index 0000000..b3a670d
--- /dev/null
@@ -0,0 +1,51 @@
+# Copyright (C) 2013 the SQLAlchemy authors and contributors <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+from .base import ischema_names
+from ... import types as sqltypes
+
+__all__ = ('INT4RANGE', 'INT8RANGE', 'NUMRANGE')
+
+class INT4RANGE(sqltypes.TypeEngine):
+    "Represent the Postgresql INT4RANGE type."
+    
+    __visit_name__ = 'INT4RANGE'
+
+ischema_names['int4range'] = INT4RANGE
+
+class INT8RANGE(sqltypes.TypeEngine):
+    "Represent the Postgresql INT8RANGE type."
+    
+    __visit_name__ = 'INT8RANGE'
+
+ischema_names['int8range'] = INT8RANGE
+
+class NUMRANGE(sqltypes.TypeEngine):
+    "Represent the Postgresql NUMRANGE type."
+    
+    __visit_name__ = 'NUMRANGE'
+
+ischema_names['numrange'] = NUMRANGE
+
+class DATERANGE(sqltypes.TypeEngine):
+    "Represent the Postgresql DATERANGE type."
+    
+    __visit_name__ = 'DATERANGE'
+
+ischema_names['daterange'] = DATERANGE
+
+class TSRANGE(sqltypes.TypeEngine):
+    "Represent the Postgresql TSRANGE type."
+    
+    __visit_name__ = 'TSRANGE'
+
+ischema_names['tsrange'] = TSRANGE
+
+class TSTZRANGE(sqltypes.TypeEngine):
+    "Represent the Postgresql TSTZRANGE type."
+    
+    __visit_name__ = 'TSTZRANGE'
+
+ischema_names['tstzrange'] = TSTZRANGE
index 00e5c07ab11c8c5976dbc40700ea7a28d76336e7..de37ffd7dc63ae367dc58304ef0e70f941c03f69 100644 (file)
@@ -17,7 +17,8 @@ from sqlalchemy import Table, Column, select, MetaData, text, Integer, \
 from sqlalchemy.orm import Session, mapper, aliased
 from sqlalchemy import exc, schema, types
 from sqlalchemy.dialects.postgresql import base as postgresql
-from sqlalchemy.dialects.postgresql import HSTORE, hstore, array
+from sqlalchemy.dialects.postgresql import HSTORE, hstore, array, \
+            INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, TSTZRANGE
 import decimal
 from sqlalchemy import util
 from sqlalchemy.testing.util import round_decimal
@@ -3232,3 +3233,124 @@ class HStoreRoundTripTest(fixtures.TablesTest):
     def test_unicode_round_trip_native(self):
         engine = testing.db
         self._test_unicode_round_trip(engine)
+
+class _RangeTypeMixin(object):
+    __requires__ = 'range_types',
+    __dialect__ = 'postgresql+psycopg2'
+
+    @property
+    def extras(self):
+        # done this way so we don't get ImportErrors with
+        # older psycopg2 versions.
+        from psycopg2 import extras
+        return extras
+    
+    @classmethod
+    def define_tables(cls, metadata):
+        # no reason ranges shouldn't be primary keys,
+        # so lets just use them as such
+        Table('data_table', metadata,
+            Column('range', cls._col_type, primary_key=True),
+        )
+
+    def test_actual_type(self):
+        eq_(str(self._col_type()), self._col_str)
+        
+    def test_reflect(self):
+        from sqlalchemy import inspect
+        insp = inspect(testing.db)
+        cols = insp.get_columns('data_table')
+        assert isinstance(cols[0]['type'], self._col_type)
+
+    def _assert_data(self):
+        data = testing.db.execute(
+            select([self.tables.data_table.c.range])
+        ).fetchall()
+        eq_(data, [(self._data_obj(), )])
+
+    def test_insert_obj(self):
+        testing.db.engine.execute(
+            self.tables.data_table.insert(),
+            {'range': self._data_obj()}
+        )
+        self._assert_data()
+
+    def test_insert_text(self):
+        testing.db.engine.execute(
+            self.tables.data_table.insert(),
+            {'range': self._data_str}
+        )
+        self._assert_data()
+
+class Int4RangeTests(_RangeTypeMixin, fixtures.TablesTest):
+
+    _col_type = INT4RANGE
+    _col_str = 'INT4RANGE'
+    _data_str = '[1,2)'
+    def _data_obj(self):
+        return self.extras.NumericRange(1, 2)
+
+class Int8RangeTests(_RangeTypeMixin, fixtures.TablesTest):
+
+    _col_type = INT8RANGE
+    _col_str = 'INT8RANGE'
+    _data_str = '[9223372036854775806,9223372036854775807)'
+    def _data_obj(self):
+        return self.extras.NumericRange(
+            9223372036854775806, 9223372036854775807
+            )
+
+class NumRangeTests(_RangeTypeMixin, fixtures.TablesTest):
+
+    _col_type = NUMRANGE
+    _col_str = 'NUMRANGE'
+    _data_str = '[1.0,2.0)'
+    def _data_obj(self):
+        return self.extras.NumericRange(
+            decimal.Decimal('1.0'), decimal.Decimal('2.0')
+            )
+
+class DateRangeTests(_RangeTypeMixin, fixtures.TablesTest):
+
+    _col_type = DATERANGE
+    _col_str = 'DATERANGE'
+    _data_str = '[2013-03-23,2013-03-24)'
+    def _data_obj(self):
+        return self.extras.DateRange(
+            datetime.date(2013, 3, 23), datetime.date(2013, 3, 24)
+            )
+
+class DateTimeRangeTests(_RangeTypeMixin, fixtures.TablesTest):
+
+    _col_type = TSRANGE
+    _col_str = 'TSRANGE'
+    _data_str = '[2013-03-23 14:30,2013-03-23 23:30)'
+    def _data_obj(self):
+        return self.extras.DateTimeRange(
+            datetime.datetime(2013, 3, 23, 14, 30),
+            datetime.datetime(2013, 3, 23, 23, 30)
+            )
+
+class DateTimeTZRangeTests(_RangeTypeMixin, fixtures.TablesTest):
+
+    _col_type = TSTZRANGE
+    _col_str = 'TSTZRANGE'
+
+    # make sure we use one, steady timestamp with timezone pair
+    # for all parts of all these tests
+    _tstzs = None
+    def tstzs(self):
+        if self._tstzs is None:
+            lower = testing.db.connect().scalar(
+                func.current_timestamp().select()
+                )
+            upper = lower+datetime.timedelta(1)
+            self._tstzs = (lower, upper)
+        return self._tstzs
+
+    @property
+    def _data_str(self):
+        return '[%s,%s)' % self.tstzs()
+    
+    def _data_obj(self):
+        return self.extras.DateTimeTZRange(*self.tstzs())
index 973ad9a10edb77bcaa0e8c7c5cfd5d4bd9c8b166..a24b84110883ba79476ef2254037a989a49873e6 100644 (file)
@@ -601,6 +601,21 @@ class DefaultRequirements(SuiteRequirements):
 
         return only_if(check_hstore)
 
+    @property
+    def range_types(self):
+        def check_range_types():
+            if not against("postgresql+psycopg2"):
+                return False
+            try:
+                self.db.execute("select '[1,2)'::int4range;")
+                # only supported in psycopg 2.5+
+                from psycopg2.extras import NumericRange
+                return True
+            except:
+                return False
+
+        return only_if(check_range_types)
+
     @property
     def sqlite(self):
         return skip_if(lambda: not self._has_sqlite())