# coding: utf-8
+from collections import defaultdict
import datetime
import decimal
from enum import Enum as _PY_Enum
from sqlalchemy import util
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import array
+from sqlalchemy.dialects.postgresql import DATEMULTIRANGE
from sqlalchemy.dialects.postgresql import DATERANGE
from sqlalchemy.dialects.postgresql import DOMAIN
from sqlalchemy.dialects.postgresql import ENUM
from sqlalchemy.dialects.postgresql import HSTORE
from sqlalchemy.dialects.postgresql import hstore
+from sqlalchemy.dialects.postgresql import INT4MULTIRANGE
from sqlalchemy.dialects.postgresql import INT4RANGE
+from sqlalchemy.dialects.postgresql import INT8MULTIRANGE
from sqlalchemy.dialects.postgresql import INT8RANGE
from sqlalchemy.dialects.postgresql import JSON
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.dialects.postgresql import NamedType
+from sqlalchemy.dialects.postgresql import NUMMULTIRANGE
from sqlalchemy.dialects.postgresql import NUMRANGE
+from sqlalchemy.dialects.postgresql import TSMULTIRANGE
from sqlalchemy.dialects.postgresql import TSRANGE
+from sqlalchemy.dialects.postgresql import TSTZMULTIRANGE
from sqlalchemy.dialects.postgresql import TSTZRANGE
from sqlalchemy.exc import CompileError
from sqlalchemy.orm import declarative_base
testing.combinations(
sqltypes.ARRAY,
postgresql.ARRAY,
- (_ArrayOfEnum, testing.requires.psycopg_compatibility),
+ (_ArrayOfEnum, testing.requires.any_psycopg_compatibility),
argnames="array_cls",
)(fn)
)
class _RangeTypeRoundTrip(fixtures.TablesTest):
- __requires__ = "range_types", "psycopg_compatibility"
+ __requires__ = "range_types", "any_psycopg_compatibility"
__backend__ = True
def extras(self):
pass
+class _MultiRangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
+ __dialect__ = "postgresql"
+
+ # operator tests
+
+ @classmethod
+ def setup_test_class(cls):
+ table = Table(
+ "data_table",
+ MetaData(),
+ Column("multirange", cls._col_type, primary_key=True),
+ )
+ cls.col = table.c.multirange
+
+ def _test_clause(self, colclause, expected, type_):
+ self.assert_compile(colclause, expected)
+ is_(colclause.type._type_affinity, type_._type_affinity)
+
+ def test_where_equal(self):
+ self._test_clause(
+ self.col == self._data_str(),
+ "data_table.multirange = %(multirange_1)s",
+ sqltypes.BOOLEANTYPE,
+ )
+
+ def test_where_not_equal(self):
+ self._test_clause(
+ self.col != self._data_str(),
+ "data_table.multirange <> %(multirange_1)s",
+ sqltypes.BOOLEANTYPE,
+ )
+
+ def test_where_is_null(self):
+ self._test_clause(
+ self.col == None,
+ "data_table.multirange IS NULL",
+ sqltypes.BOOLEANTYPE,
+ )
+
+ def test_where_is_not_null(self):
+ self._test_clause(
+ self.col != None,
+ "data_table.multirange IS NOT NULL",
+ sqltypes.BOOLEANTYPE,
+ )
+
+ def test_where_less_than(self):
+ self._test_clause(
+ self.col < self._data_str(),
+ "data_table.multirange < %(multirange_1)s",
+ sqltypes.BOOLEANTYPE,
+ )
+
+ def test_where_greater_than(self):
+ self._test_clause(
+ self.col > self._data_str(),
+ "data_table.multirange > %(multirange_1)s",
+ sqltypes.BOOLEANTYPE,
+ )
+
+ def test_where_less_than_or_equal(self):
+ self._test_clause(
+ self.col <= self._data_str(),
+ "data_table.multirange <= %(multirange_1)s",
+ sqltypes.BOOLEANTYPE,
+ )
+
+ def test_where_greater_than_or_equal(self):
+ self._test_clause(
+ self.col >= self._data_str(),
+ "data_table.multirange >= %(multirange_1)s",
+ sqltypes.BOOLEANTYPE,
+ )
+
+ def test_contains(self):
+ self._test_clause(
+ self.col.contains(self._data_str()),
+ "data_table.multirange @> %(multirange_1)s",
+ sqltypes.BOOLEANTYPE,
+ )
+
+ def test_contained_by(self):
+ self._test_clause(
+ self.col.contained_by(self._data_str()),
+ "data_table.multirange <@ %(multirange_1)s",
+ sqltypes.BOOLEANTYPE,
+ )
+
+ def test_overlaps(self):
+ self._test_clause(
+ self.col.overlaps(self._data_str()),
+ "data_table.multirange && %(multirange_1)s",
+ sqltypes.BOOLEANTYPE,
+ )
+
+ def test_strictly_left_of(self):
+ self._test_clause(
+ self.col << self._data_str(),
+ "data_table.multirange << %(multirange_1)s",
+ sqltypes.BOOLEANTYPE,
+ )
+ self._test_clause(
+ self.col.strictly_left_of(self._data_str()),
+ "data_table.multirange << %(multirange_1)s",
+ sqltypes.BOOLEANTYPE,
+ )
+
+ def test_strictly_right_of(self):
+ self._test_clause(
+ self.col >> self._data_str(),
+ "data_table.multirange >> %(multirange_1)s",
+ sqltypes.BOOLEANTYPE,
+ )
+ self._test_clause(
+ self.col.strictly_right_of(self._data_str()),
+ "data_table.multirange >> %(multirange_1)s",
+ sqltypes.BOOLEANTYPE,
+ )
+
+ def test_not_extend_right_of(self):
+ self._test_clause(
+ self.col.not_extend_right_of(self._data_str()),
+ "data_table.multirange &< %(multirange_1)s",
+ sqltypes.BOOLEANTYPE,
+ )
+
+ def test_not_extend_left_of(self):
+ self._test_clause(
+ self.col.not_extend_left_of(self._data_str()),
+ "data_table.multirange &> %(multirange_1)s",
+ sqltypes.BOOLEANTYPE,
+ )
+
+ def test_adjacent_to(self):
+ self._test_clause(
+ self.col.adjacent_to(self._data_str()),
+ "data_table.multirange -|- %(multirange_1)s",
+ sqltypes.BOOLEANTYPE,
+ )
+
+ def test_union(self):
+ self._test_clause(
+ self.col + self.col,
+ "data_table.multirange + data_table.multirange",
+ self.col.type,
+ )
+
+ def test_intersection(self):
+ self._test_clause(
+ self.col * self.col,
+ "data_table.multirange * data_table.multirange",
+ self.col.type,
+ )
+
+ def test_different(self):
+ self._test_clause(
+ self.col - self.col,
+ "data_table.multirange - data_table.multirange",
+ self.col.type,
+ )
+
+
+class _MultiRangeTypeRoundTrip(fixtures.TablesTest):
+ __requires__ = "range_types", "psycopg_only_compatibility"
+ __backend__ = True
+
+ def extras(self):
+ # done this way so we don't get ImportErrors with
+ # older psycopg2 versions.
+ if testing.against("postgresql+psycopg"):
+ from psycopg.types.range import Range
+ from psycopg.types.multirange import Multirange
+
+ class psycopg_extras:
+ def __init__(self):
+ self.data = defaultdict(
+ lambda: Range, Multirange=Multirange
+ )
+
+ def __getattr__(self, name):
+ return self.data[name]
+
+ extras = psycopg_extras()
+ else:
+ assert False, "Unsupported MultiRange Dialect"
+ return extras
+
+ @classmethod
+ def define_tables(cls, metadata):
+ # no reason ranges shouldn't be primary keys,
+ # so lets just use them as such
+ table = Table(
+ "data_table",
+ metadata,
+ Column("range", cls._col_type, primary_key=True),
+ )
+ cls.col = table.c.range
+
+ def test_actual_type(self):
+ eq_(str(self._col_type()), self._col_str)
+
+ def test_reflect(self, connection):
+ from sqlalchemy import inspect
+
+ insp = inspect(connection)
+ cols = insp.get_columns("data_table")
+ assert isinstance(cols[0]["type"], self._col_type)
+
+ def _assert_data(self, conn):
+ data = conn.execute(select(self.tables.data_table.c.range)).fetchall()
+ eq_(data, [(self._data_obj(),)])
+
+ def test_insert_obj(self, connection):
+ connection.execute(
+ self.tables.data_table.insert(), {"range": self._data_obj()}
+ )
+ self._assert_data(connection)
+
+ def test_insert_text(self, connection):
+ connection.execute(
+ self.tables.data_table.insert(), {"range": self._data_str()}
+ )
+ self._assert_data(connection)
+
+ def test_union_result(self, connection):
+ # insert
+ connection.execute(
+ self.tables.data_table.insert(), {"range": self._data_str()}
+ )
+ # select
+ range_ = self.tables.data_table.c.range
+ data = connection.execute(select(range_ + range_)).fetchall()
+ eq_(data, [(self._data_obj(),)])
+
+ def test_intersection_result(self, connection):
+ # insert
+ connection.execute(
+ self.tables.data_table.insert(), {"range": self._data_str()}
+ )
+ # select
+ range_ = self.tables.data_table.c.range
+ data = connection.execute(select(range_ * range_)).fetchall()
+ eq_(data, [(self._data_obj(),)])
+
+ def test_difference_result(self, connection):
+ # insert
+ connection.execute(
+ self.tables.data_table.insert(), {"range": self._data_str()}
+ )
+ # select
+ range_ = self.tables.data_table.c.range
+ data = connection.execute(select(range_ - range_)).fetchall()
+ eq_(data, [(self.extras().Multirange(),)])
+
+
+class _Int4MultiRangeTests:
+
+ _col_type = INT4MULTIRANGE
+ _col_str = "INT4MULTIRANGE"
+
+ def _data_str(self):
+ return "{[1,2), [3, 5), [9, 12)}"
+
+ def _data_obj(self):
+ return self.extras().Multirange(
+ [
+ self.extras().Range(1, 2),
+ self.extras().Range(3, 5),
+ self.extras().Range(9, 12),
+ ]
+ )
+
+
+class _Int8MultiRangeTests:
+
+ _col_type = INT8MULTIRANGE
+ _col_str = "INT8MULTIRANGE"
+
+ def _data_str(self):
+ return (
+ "{[9223372036854775801,9223372036854775803),"
+ + "[9223372036854775805,9223372036854775807)}"
+ )
+
+ def _data_obj(self):
+ return self.extras().Multirange(
+ [
+ self.extras().Range(9223372036854775801, 9223372036854775803),
+ self.extras().Range(9223372036854775805, 9223372036854775807),
+ ]
+ )
+
+
+class _NumMultiRangeTests:
+
+ _col_type = NUMMULTIRANGE
+ _col_str = "NUMMULTIRANGE"
+
+ def _data_str(self):
+ return "{[1.0,2.0), [3.0, 5.0), [9.0, 12.0)}"
+
+ def _data_obj(self):
+ return self.extras().Multirange(
+ [
+ self.extras().Range(
+ decimal.Decimal("1.0"), decimal.Decimal("2.0")
+ ),
+ self.extras().Range(
+ decimal.Decimal("3.0"), decimal.Decimal("5.0")
+ ),
+ self.extras().Range(
+ decimal.Decimal("9.0"), decimal.Decimal("12.0")
+ ),
+ ]
+ )
+
+
+class _DateMultiRangeTests:
+
+ _col_type = DATEMULTIRANGE
+ _col_str = "DATEMULTIRANGE"
+
+ def _data_str(self):
+ return "{[2013-03-23,2013-03-24), [2014-05-23,2014-05-24)}"
+
+ def _data_obj(self):
+ return self.extras().Multirange(
+ [
+ self.extras().Range(
+ datetime.date(2013, 3, 23), datetime.date(2013, 3, 24)
+ ),
+ self.extras().Range(
+ datetime.date(2014, 5, 23), datetime.date(2014, 5, 24)
+ ),
+ ]
+ )
+
+
+class _DateTimeMultiRangeTests:
+
+ _col_type = TSMULTIRANGE
+ _col_str = "TSMULTIRANGE"
+
+ def _data_str(self):
+ return (
+ "{[2013-03-23 14:30,2013-03-23 23:30),"
+ + "[2014-05-23 14:30,2014-05-23 23:30)}"
+ )
+
+ def _data_obj(self):
+ return self.extras().Multirange(
+ [
+ self.extras().Range(
+ datetime.datetime(2013, 3, 23, 14, 30),
+ datetime.datetime(2013, 3, 23, 23, 30),
+ ),
+ self.extras().Range(
+ datetime.datetime(2014, 5, 23, 14, 30),
+ datetime.datetime(2014, 5, 23, 23, 30),
+ ),
+ ]
+ )
+
+
+class _DateTimeTZMultiRangeTests:
+
+ _col_type = TSTZMULTIRANGE
+ _col_str = "TSTZMULTIRANGE"
+
+ # make sure we use one, steady timestamp with timezone pair
+ # for all parts of all these tests
+ _tstzs = None
+ _tstzs_delta = None
+
+ def tstzs(self):
+ if self._tstzs is None:
+ with testing.db.connect() as connection:
+ lower = connection.scalar(func.current_timestamp().select())
+ upper = lower + datetime.timedelta(1)
+ self._tstzs = (lower, upper)
+ return self._tstzs
+
+ def tstzs_delta(self):
+ if self._tstzs_delta is None:
+ with testing.db.connect() as connection:
+ lower = connection.scalar(
+ func.current_timestamp().select()
+ ) + datetime.timedelta(3)
+ upper = lower + datetime.timedelta(2)
+ self._tstzs_delta = (lower, upper)
+ return self._tstzs_delta
+
+ def _data_str(self):
+ tstzs_lower, tstzs_upper = self.tstzs()
+ tstzs_delta_lower, tstzs_delta_upper = self.tstzs_delta()
+ return "{{[{tl},{tu}), [{tdl},{tdu})}}".format(
+ tl=tstzs_lower,
+ tu=tstzs_upper,
+ tdl=tstzs_delta_lower,
+ tdu=tstzs_delta_upper,
+ )
+
+ def _data_obj(self):
+ return self.extras().Multirange(
+ [
+ self.extras().Range(*self.tstzs()),
+ self.extras().Range(*self.tstzs_delta()),
+ ]
+ )
+
+
+class Int4MultiRangeCompilationTest(
+ _Int4MultiRangeTests, _MultiRangeTypeCompilation
+):
+ pass
+
+
+class Int4MultiRangeRoundTripTest(
+ _Int4MultiRangeTests, _MultiRangeTypeRoundTrip
+):
+ pass
+
+
+class Int8MultiRangeCompilationTest(
+ _Int8MultiRangeTests, _MultiRangeTypeCompilation
+):
+ pass
+
+
+class Int8MultiRangeRoundTripTest(
+ _Int8MultiRangeTests, _MultiRangeTypeRoundTrip
+):
+ pass
+
+
+class NumMultiRangeCompilationTest(
+ _NumMultiRangeTests, _MultiRangeTypeCompilation
+):
+ pass
+
+
+class NumMultiRangeRoundTripTest(
+ _NumMultiRangeTests, _MultiRangeTypeRoundTrip
+):
+ pass
+
+
+class DateMultiRangeCompilationTest(
+ _DateMultiRangeTests, _MultiRangeTypeCompilation
+):
+ pass
+
+
+class DateMultiRangeRoundTripTest(
+ _DateMultiRangeTests, _MultiRangeTypeRoundTrip
+):
+ pass
+
+
+class DateTimeMultiRangeCompilationTest(
+ _DateTimeMultiRangeTests, _MultiRangeTypeCompilation
+):
+ pass
+
+
+class DateTimeMultiRangeRoundTripTest(
+ _DateTimeMultiRangeTests, _MultiRangeTypeRoundTrip
+):
+ pass
+
+
+class DateTimeTZMultiRangeCompilationTest(
+ _DateTimeTZMultiRangeTests, _MultiRangeTypeCompilation
+):
+ pass
+
+
+class DateTimeTZRMultiangeRoundTripTest(
+ _DateTimeTZMultiRangeTests, _MultiRangeTypeRoundTrip
+):
+ pass
+
+
class JSONTest(AssertsCompiledSQL, fixtures.TestBase):
__dialect__ = "postgresql"