From: mike bayer Date: Sat, 22 Jun 2013 14:47:02 +0000 (-0700) Subject: Merge pull request #5 from cjw296/pg-ranges X-Git-Tag: rel_0_8_2~38 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f52d45672b0206a866258aa5291096816c8563b4;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Merge pull request #5 from cjw296/pg-ranges Support for Postgres range types. --- diff --git a/doc/build/dialects/postgresql.rst b/doc/build/dialects/postgresql.rst index df141cce00..f2b401e9a5 100644 --- a/doc/build/dialects/postgresql.rst +++ b/doc/build/dialects/postgresql.rst @@ -16,7 +16,8 @@ they originate from :mod:`sqlalchemy.types` or from the local dialect:: ARRAY, BIGINT, BIT, BOOLEAN, BYTEA, CHAR, CIDR, DATE, \ DOUBLE_PRECISION, ENUM, FLOAT, HSTORE, INET, INTEGER, \ INTERVAL, MACADDR, NUMERIC, REAL, SMALLINT, TEXT, TIME, \ - TIMESTAMP, UUID, VARCHAR + TIMESTAMP, UUID, VARCHAR, INT4RANGE, INT8RANGE, NUMRANGE, \ + DATERANGE, TSRANGE, TSTZRANGE Types which are specific to PostgreSQL, or have PostgreSQL-specific construction arguments, are as follows: @@ -81,6 +82,54 @@ construction arguments, are as follows: :members: __init__ :show-inheritance: +.. autoclass:: sqlalchemy.dialects.postgresql.ranges.RangeOperators + :members: + +.. autoclass:: INT4RANGE + :show-inheritance: + +.. autoclass:: INT8RANGE + :show-inheritance: + +.. autoclass:: NUMRANGE + :show-inheritance: + +.. autoclass:: DATERANGE + :show-inheritance: + +.. autoclass:: TSRANGE + :show-inheritance: + +.. autoclass:: TSTZRANGE + :show-inheritance: + + +PostgreSQL Constraint Types +--------------------------- + +SQLAlchemy supports Postgresql EXCLUDE constraints via the +:class:`ExcludeConstraint` class: + +.. autoclass:: ExcludeConstraint + :show-inheritance: + :members: __init__ + +For example:: + + from sqlalchemy.dialects.postgresql import ( + ExcludeConstraint, + TSRANGE as Range, + ) + + class RoomBookings(Base): + + room = Column(Integer(), primary_key=True) + during = Column(TSRANGE()) + + __table_args__ = ( + ExcludeConstraint(('room', '='), ('during', '&&')), + ) + psycopg2 -------------- diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index d0f785bdda..408b678467 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -12,12 +12,16 @@ from .base import \ INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, \ INET, CIDR, UUID, BIT, MACADDR, DOUBLE_PRECISION, TIMESTAMP, TIME, \ DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect, array, Any, All +from .constraints import ExcludeConstraint from .hstore import HSTORE, hstore +from .ranges import INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, \ + TSTZRANGE __all__ = ( 'INTEGER', 'BIGINT', 'SMALLINT', 'VARCHAR', 'CHAR', 'TEXT', 'NUMERIC', 'FLOAT', 'REAL', 'INET', 'CIDR', 'UUID', 'BIT', 'MACADDR', 'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME', 'DATE', 'BYTEA', 'BOOLEAN', 'INTERVAL', 'ARRAY', 'ENUM', 'dialect', 'Any', 'All', 'array', 'HSTORE', - 'hstore' + 'hstore', 'INT4RANGE', 'INT8RANGE', 'NUMRANGE', 'DATERANGE', + 'TSRANGE', 'TSTZRANGE' ) diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 8c9af5bcf6..e1fc9e2111 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -444,7 +444,7 @@ class array(expression.Tuple): An instance of :class:`.array` will always have the datatype :class:`.ARRAY`. The "inner" type of the array is inferred from - the values present, unless the "type_" keyword argument is passed:: + the values present, unless the ``type_`` keyword argument is passed:: array(['foo', 'bar'], type_=CHAR) @@ -1142,6 +1142,22 @@ class PGDDLCompiler(compiler.DDLCompiler): text += " WHERE " + where_compiled return text + def visit_exclude_constraint(self, constraint): + text = "" + if constraint.name is not None: + text += "CONSTRAINT %s " % \ + self.preparer.format_constraint(constraint) + elements = [] + for c in constraint.columns: + op = constraint.operators[c.name] + elements.append(self.preparer.quote(c.name, c.quote)+' WITH '+op) + text += "EXCLUDE USING %s (%s)" % (constraint.using, ', '.join(elements)) + if constraint.where is not None: + sqltext = sql_util.expression_as_ddl(constraint.where) + text += ' WHERE (%s)' % self.sql_compiler.process(sqltext) + text += self.define_constraint_deferrability(constraint) + return text + class PGTypeCompiler(compiler.GenericTypeCompiler): def visit_INET(self, type_): @@ -1168,6 +1184,24 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): def visit_HSTORE(self, type_): return "HSTORE" + def visit_INT4RANGE(self, type_): + return "INT4RANGE" + + def visit_INT8RANGE(self, type_): + return "INT8RANGE" + + def visit_NUMRANGE(self, type_): + return "NUMRANGE" + + def visit_DATERANGE(self, type_): + return "DATERANGE" + + def visit_TSRANGE(self, type_): + return "TSRANGE" + + def visit_TSTZRANGE(self, type_): + return "TSTZRANGE" + def visit_datetime(self, type_): return self.visit_TIMESTAMP(type_) diff --git a/lib/sqlalchemy/dialects/postgresql/constraints.py b/lib/sqlalchemy/dialects/postgresql/constraints.py new file mode 100644 index 0000000000..5b8bbe6430 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/constraints.py @@ -0,0 +1,73 @@ +# Copyright (C) 2013 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +from sqlalchemy.schema import ColumnCollectionConstraint +from sqlalchemy.sql import expression + +class ExcludeConstraint(ColumnCollectionConstraint): + """A table-level EXCLUDE constraint. + + Defines an EXCLUDE constraint as described in the `postgres + documentation`__. + + __ http://www.postgresql.org/docs/9.0/static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE + """ + + __visit_name__ = 'exclude_constraint' + + where = None + + def __init__(self, *elements, **kw): + """ + :param \*elements: + A sequence of two tuples of the form ``(column, operator)`` where + column must be a column name or Column object and operator must + be a string containing the operator to use. + + :param name: + Optional, the in-database name of this constraint. + + :param deferrable: + Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when + issuing DDL for this constraint. + + :param initially: + Optional string. If set, emit INITIALLY when issuing DDL + for this constraint. + + :param using: + Optional string. If set, emit USING when issuing DDL + for this constraint. Defaults to 'gist'. + + :param where: + Optional string. If set, emit WHERE when issuing DDL + for this constraint. + + """ + ColumnCollectionConstraint.__init__( + self, + *[col for col, op in elements], + name=kw.get('name'), + deferrable=kw.get('deferrable'), + initially=kw.get('initially') + ) + self.operators = {} + for col_or_string, op in elements: + name = getattr(col_or_string, 'name', col_or_string) + self.operators[name] = op + self.using = kw.get('using', 'gist') + where = kw.get('where') + if where: + self.where = expression._literal_as_text(where) + + def copy(self, **kw): + elements = [(col, self.operators[col]) + for col in self.columns.keys()] + c = self.__class__(*elements, + name=self.name, + deferrable=self.deferrable, + initially=self.initially) + c.dispatch._update(self.dispatch) + return c + diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py new file mode 100644 index 0000000000..e7ab1d5b53 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -0,0 +1,133 @@ +# Copyright (C) 2013 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from .base import ischema_names +from ... import types as sqltypes + +__all__ = ('INT4RANGE', 'INT8RANGE', 'NUMRANGE') + +class RangeOperators(object): + """ + This mixin provides functionality for the Range Operators + listed in Table 9-44 of the `postgres documentation`__ for Range + Functions and Operators. It is used by all the range types + provided in the ``postgres`` dialect and can likely be used for + any range types you create yourself. + + __ http://www.postgresql.org/docs/devel/static/functions-range.html + + No extra support is provided for the Range Functions listed in + Table 9-45 of the postgres documentation. For these, the normal + :func:`~sqlalchemy.sql.expression.func` object should be used. + """ + + class comparator_factory(sqltypes.Concatenable.Comparator): + """Define comparison operations for range types.""" + + def __ne__(self, other): + "Boolean expression. Returns true if two ranges are not equal" + return self.expr.op('<>')(other) + + def contains(self, other, **kw): + """Boolean expression. Returns true if the right hand operand, + which can be an element or a range, is contained within the + column. + """ + return self.expr.op('@>')(other) + + def contained_by(self, other): + """Boolean expression. Returns true if the column is contained + within the right hand operand. + """ + return self.expr.op('<@')(other) + + def overlaps(self, other): + """Boolean expression. Returns true if the column overlaps + (has points in common with) the right hand operand. + """ + return self.expr.op('&&')(other) + + def strictly_left_of(self, other): + """Boolean expression. Returns true if the column is strictly + left of the right hand operand. + """ + return self.expr.op('<<')(other) + + __lshift__ = strictly_left_of + + def strictly_right_of(self, other): + """Boolean expression. Returns true if the column is strictly + right of the right hand operand. + """ + return self.expr.op('>>')(other) + + __rshift__ = strictly_right_of + + def not_extend_right_of(self, other): + """Boolean expression. Returns true if the range in the column + does not extend right of the range in the operand. + """ + return self.expr.op('&<')(other) + + def not_extend_left_of(self, other): + """Boolean expression. Returns true if the range in the column + does not extend left of the range in the operand. + """ + return self.expr.op('&>')(other) + + def adjacent_to(self, other): + """Boolean expression. Returns true if the range in the column + is adjacent to the range in the operand. + """ + return self.expr.op('-|-')(other) + + def __add__(self, other): + """Range expression. Returns the union of the two ranges. + Will raise an exception if the resulting range is not + contigous. + """ + return self.expr.op('+')(other) + +class INT4RANGE(RangeOperators, sqltypes.TypeEngine): + "Represent the Postgresql INT4RANGE type." + + __visit_name__ = 'INT4RANGE' + +ischema_names['int4range'] = INT4RANGE + +class INT8RANGE(RangeOperators, sqltypes.TypeEngine): + "Represent the Postgresql INT8RANGE type." + + __visit_name__ = 'INT8RANGE' + +ischema_names['int8range'] = INT8RANGE + +class NUMRANGE(RangeOperators, sqltypes.TypeEngine): + "Represent the Postgresql NUMRANGE type." + + __visit_name__ = 'NUMRANGE' + +ischema_names['numrange'] = NUMRANGE + +class DATERANGE(RangeOperators, sqltypes.TypeEngine): + "Represent the Postgresql DATERANGE type." + + __visit_name__ = 'DATERANGE' + +ischema_names['daterange'] = DATERANGE + +class TSRANGE(RangeOperators, sqltypes.TypeEngine): + "Represent the Postgresql TSRANGE type." + + __visit_name__ = 'TSRANGE' + +ischema_names['tsrange'] = TSRANGE + +class TSTZRANGE(RangeOperators, sqltypes.TypeEngine): + "Represent the Postgresql TSTZRANGE type." + + __visit_name__ = 'TSTZRANGE' + +ischema_names['tstzrange'] = TSTZRANGE diff --git a/test/dialect/test_postgresql.py b/test/dialect/test_postgresql.py index 43c4719190..1389fe5f87 100644 --- a/test/dialect/test_postgresql.py +++ b/test/dialect/test_postgresql.py @@ -17,7 +17,9 @@ from sqlalchemy import Table, Column, select, MetaData, text, Integer, \ from sqlalchemy.orm import Session, mapper, aliased from sqlalchemy import exc, schema, types from sqlalchemy.dialects.postgresql import base as postgresql -from sqlalchemy.dialects.postgresql import HSTORE, hstore, array +from sqlalchemy.dialects.postgresql import HSTORE, hstore, array, \ + INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, TSTZRANGE, \ + ExcludeConstraint import decimal from sqlalchemy import util from sqlalchemy.testing.util import round_decimal @@ -182,6 +184,53 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): 'USING hash (data)', dialect=postgresql.dialect()) + def test_exclude_constraint_min(self): + m = MetaData() + tbl = Table('testtbl', m, + Column('room', Integer, primary_key=True)) + cons = ExcludeConstraint(('room', '=')) + tbl.append_constraint(cons) + self.assert_compile(schema.AddConstraint(cons), + 'ALTER TABLE testtbl ADD EXCLUDE USING gist ' + '(room WITH =)', + dialect=postgresql.dialect()) + + def test_exclude_constraint_full(self): + m = MetaData() + room = Column('room', Integer, primary_key=True) + tbl = Table('testtbl', m, + room, + Column('during', TSRANGE)) + room = Column('room', Integer, primary_key=True) + cons = ExcludeConstraint((room, '='), ('during', '&&'), + name='my_name', + using='gist', + where="room > 100", + deferrable=True, + initially='immediate') + tbl.append_constraint(cons) + self.assert_compile(schema.AddConstraint(cons), + 'ALTER TABLE testtbl ADD CONSTRAINT my_name ' + 'EXCLUDE USING gist ' + '(room WITH =, during WITH ''&&) WHERE ' + '(room > 100) DEFERRABLE INITIALLY immediate', + dialect=postgresql.dialect()) + + def test_exclude_constraint_copy(self): + m = MetaData() + cons = ExcludeConstraint(('room', '=')) + tbl = Table('testtbl', m, + Column('room', Integer, primary_key=True), + cons) + # apparently you can't copy a ColumnCollectionConstraint until + # after it has been bound to a table... + cons_copy = cons.copy() + tbl.append_constraint(cons_copy) + self.assert_compile(schema.AddConstraint(cons_copy), + 'ALTER TABLE testtbl ADD EXCLUDE USING gist ' + '(room WITH =)', + dialect=postgresql.dialect()) + def test_substring(self): self.assert_compile(func.substring('abc', 1, 2), 'SUBSTRING(%(substring_1)s FROM %(substring_2)s ' @@ -3241,3 +3290,282 @@ class HStoreRoundTripTest(fixtures.TablesTest): def test_unicode_round_trip_native(self): engine = testing.db self._test_unicode_round_trip(engine) + +class _RangeTypeMixin(object): + __requires__ = 'range_types', + __dialect__ = 'postgresql+psycopg2' + + @property + def extras(self): + # done this way so we don't get ImportErrors with + # older psycopg2 versions. + from psycopg2 import extras + return extras + + @classmethod + def define_tables(cls, metadata): + # no reason ranges shouldn't be primary keys, + # so lets just use them as such + table = Table('data_table', metadata, + Column('range', cls._col_type, primary_key=True), + ) + cls.col = table.c.range + + def test_actual_type(self): + eq_(str(self._col_type()), self._col_str) + + def test_reflect(self): + from sqlalchemy import inspect + insp = inspect(testing.db) + cols = insp.get_columns('data_table') + assert isinstance(cols[0]['type'], self._col_type) + + def _assert_data(self): + data = testing.db.execute( + select([self.tables.data_table.c.range]) + ).fetchall() + eq_(data, [(self._data_obj(), )]) + + def test_insert_obj(self): + testing.db.engine.execute( + self.tables.data_table.insert(), + {'range': self._data_obj()} + ) + self._assert_data() + + def test_insert_text(self): + testing.db.engine.execute( + self.tables.data_table.insert(), + {'range': self._data_str} + ) + self._assert_data() + + # operator tests + + def _test_clause(self, colclause, expected): + dialect = postgresql.dialect() + compiled = str(colclause.compile(dialect=dialect)) + eq_(compiled, expected) + + def test_where_equal(self): + self._test_clause( + self.col==self._data_str, + "data_table.range = %(range_1)s" + ) + + def test_where_not_equal(self): + self._test_clause( + self.col!=self._data_str, + "data_table.range <> %(range_1)s" + ) + + def test_where_less_than(self): + self._test_clause( + self.col < self._data_str, + "data_table.range < %(range_1)s" + ) + + def test_where_greater_than(self): + self._test_clause( + self.col > self._data_str, + "data_table.range > %(range_1)s" + ) + + def test_where_less_than_or_equal(self): + self._test_clause( + self.col <= self._data_str, + "data_table.range <= %(range_1)s" + ) + + def test_where_greater_than_or_equal(self): + self._test_clause( + self.col >= self._data_str, + "data_table.range >= %(range_1)s" + ) + + def test_contains(self): + self._test_clause( + self.col.contains(self._data_str), + "data_table.range @> %(range_1)s" + ) + + def test_contained_by(self): + self._test_clause( + self.col.contained_by(self._data_str), + "data_table.range <@ %(range_1)s" + ) + + def test_overlaps(self): + self._test_clause( + self.col.overlaps(self._data_str), + "data_table.range && %(range_1)s" + ) + + def test_strictly_left_of(self): + self._test_clause( + self.col << self._data_str, + "data_table.range << %(range_1)s" + ) + self._test_clause( + self.col.strictly_left_of(self._data_str), + "data_table.range << %(range_1)s" + ) + + def test_strictly_right_of(self): + self._test_clause( + self.col >> self._data_str, + "data_table.range >> %(range_1)s" + ) + self._test_clause( + self.col.strictly_right_of(self._data_str), + "data_table.range >> %(range_1)s" + ) + + def test_not_extend_right_of(self): + self._test_clause( + self.col.not_extend_right_of(self._data_str), + "data_table.range &< %(range_1)s" + ) + + def test_not_extend_left_of(self): + self._test_clause( + self.col.not_extend_left_of(self._data_str), + "data_table.range &> %(range_1)s" + ) + + def test_adjacent_to(self): + self._test_clause( + self.col.adjacent_to(self._data_str), + "data_table.range -|- %(range_1)s" + ) + + def test_union(self): + self._test_clause( + self.col + self.col, + "data_table.range + data_table.range" + ) + + def test_union_result(self): + # insert + testing.db.engine.execute( + self.tables.data_table.insert(), + {'range': self._data_str} + ) + # select + range = self.tables.data_table.c.range + data = testing.db.execute( + select([range + range]) + ).fetchall() + eq_(data, [(self._data_obj(), )]) + + + def test_intersection(self): + self._test_clause( + self.col * self.col, + "data_table.range * data_table.range" + ) + + def test_intersection_result(self): + # insert + testing.db.engine.execute( + self.tables.data_table.insert(), + {'range': self._data_str} + ) + # select + range = self.tables.data_table.c.range + data = testing.db.execute( + select([range * range]) + ).fetchall() + eq_(data, [(self._data_obj(), )]) + + def test_different(self): + self._test_clause( + self.col - self.col, + "data_table.range - data_table.range" + ) + + def test_difference_result(self): + # insert + testing.db.engine.execute( + self.tables.data_table.insert(), + {'range': self._data_str} + ) + # select + range = self.tables.data_table.c.range + data = testing.db.execute( + select([range - range]) + ).fetchall() + eq_(data, [(self._data_obj().__class__(empty=True), )]) + +class Int4RangeTests(_RangeTypeMixin, fixtures.TablesTest): + + _col_type = INT4RANGE + _col_str = 'INT4RANGE' + _data_str = '[1,2)' + def _data_obj(self): + return self.extras.NumericRange(1, 2) + +class Int8RangeTests(_RangeTypeMixin, fixtures.TablesTest): + + _col_type = INT8RANGE + _col_str = 'INT8RANGE' + _data_str = '[9223372036854775806,9223372036854775807)' + def _data_obj(self): + return self.extras.NumericRange( + 9223372036854775806, 9223372036854775807 + ) + +class NumRangeTests(_RangeTypeMixin, fixtures.TablesTest): + + _col_type = NUMRANGE + _col_str = 'NUMRANGE' + _data_str = '[1.0,2.0)' + def _data_obj(self): + return self.extras.NumericRange( + decimal.Decimal('1.0'), decimal.Decimal('2.0') + ) + +class DateRangeTests(_RangeTypeMixin, fixtures.TablesTest): + + _col_type = DATERANGE + _col_str = 'DATERANGE' + _data_str = '[2013-03-23,2013-03-24)' + def _data_obj(self): + return self.extras.DateRange( + datetime.date(2013, 3, 23), datetime.date(2013, 3, 24) + ) + +class DateTimeRangeTests(_RangeTypeMixin, fixtures.TablesTest): + + _col_type = TSRANGE + _col_str = 'TSRANGE' + _data_str = '[2013-03-23 14:30,2013-03-23 23:30)' + def _data_obj(self): + return self.extras.DateTimeRange( + datetime.datetime(2013, 3, 23, 14, 30), + datetime.datetime(2013, 3, 23, 23, 30) + ) + +class DateTimeTZRangeTests(_RangeTypeMixin, fixtures.TablesTest): + + _col_type = TSTZRANGE + _col_str = 'TSTZRANGE' + + # make sure we use one, steady timestamp with timezone pair + # for all parts of all these tests + _tstzs = None + def tstzs(self): + if self._tstzs is None: + lower = testing.db.connect().scalar( + func.current_timestamp().select() + ) + upper = lower+datetime.timedelta(1) + self._tstzs = (lower, upper) + return self._tstzs + + @property + def _data_str(self): + return '[%s,%s)' % self.tstzs() + + def _data_obj(self): + return self.extras.DateTimeTZRange(*self.tstzs()) diff --git a/test/requirements.py b/test/requirements.py index 42128669d8..807f21baca 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -585,6 +585,21 @@ class DefaultRequirements(SuiteRequirements): return only_if(check_hstore) + @property + def range_types(self): + def check_range_types(): + if not against("postgresql+psycopg2"): + return False + try: + self.db.execute("select '[1,2)'::int4range;") + # only supported in psycopg 2.5+ + from psycopg2.extras import NumericRange + return True + except: + return False + + return only_if(check_range_types) + @property def sqlite(self): return skip_if(lambda: not self._has_sqlite())