]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement EXCLUDE constraints for postgres.
authorChris Withers <chris@simplistix.co.uk>
Tue, 21 May 2013 20:11:35 +0000 (21:11 +0100)
committerChris Withers <chris@simplistix.co.uk>
Mon, 10 Jun 2013 11:09:56 +0000 (12:09 +0100)
lib/sqlalchemy/dialects/postgresql/__init__.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/constraints.py [new file with mode: 0644]
test/dialect/test_postgresql.py

index 3c259671d94558c75d979889dbeeba449bedbf38..408b678467ded92201554db7dc0013f8f9b01e2e 100644 (file)
@@ -12,6 +12,7 @@ from .base import \
     INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, \
     INET, CIDR, UUID, BIT, MACADDR, DOUBLE_PRECISION, TIMESTAMP, TIME, \
     DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect, array, Any, All
+from .constraints import ExcludeConstraint
 from .hstore import HSTORE, hstore
 from .ranges import INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, \
     TSTZRANGE
index 127e1130b193191dab176c7b4daac13710f4c380..4a6de0ceb96f8e684b7872d21d338a7cf6e33b4c 100644 (file)
@@ -1124,6 +1124,22 @@ class PGDDLCompiler(compiler.DDLCompiler):
             text += " WHERE " + where_compiled
         return text
 
+    def visit_exclude_constraint(self, constraint):
+        text = ""
+        if constraint.name is not None:
+            text += "CONSTRAINT %s " % \
+                    self.preparer.format_constraint(constraint)
+        elements = []
+        for c in constraint.columns:
+            op = constraint.operators[c.name]
+            elements.append(self.preparer.quote(c.name, c.quote)+' WITH '+op)
+        text += "EXCLUDE USING %s (%s)" % (constraint.using, ', '.join(elements))
+        if constraint.where is not None:
+            sqltext = sql_util.expression_as_ddl(constraint.where)
+            text += ' WHERE (%s)' % self.sql_compiler.process(sqltext)
+        text += self.define_constraint_deferrability(constraint)
+        return text
+
 
 class PGTypeCompiler(compiler.GenericTypeCompiler):
     def visit_INET(self, type_):
diff --git a/lib/sqlalchemy/dialects/postgresql/constraints.py b/lib/sqlalchemy/dialects/postgresql/constraints.py
new file mode 100644 (file)
index 0000000..88d688a
--- /dev/null
@@ -0,0 +1,73 @@
+# Copyright (C) 2013 the SQLAlchemy authors and contributors <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+from sqlalchemy.schema import ColumnCollectionConstraint
+from sqlalchemy.sql import expression
+
+class ExcludeConstraint(ColumnCollectionConstraint):
+    """A table-level UNIQUE constraint.
+
+    Defines a single column or composite UNIQUE constraint. For a no-frills,
+    single column constraint, adding ``unique=True`` to the ``Column``
+    definition is a shorthand equivalent for an unnamed, single column
+    UniqueConstraint.
+    """
+
+    __visit_name__ = 'exclude_constraint'
+
+    where = None
+
+    def __init__(self, *elements, **kw):
+        """
+        :param \*elements:
+          A sequence of two tuples of the form ``(column, operator)`` where
+          column must be a column name or Column object and operator must
+          be a string containing the operator to use.
+
+        :param name:
+          Optional, the in-database name of this constraint.
+
+        :param deferrable:
+          Optional bool.  If set, emit DEFERRABLE or NOT DEFERRABLE when
+          issuing DDL for this constraint.
+
+        :param initially:
+          Optional string.  If set, emit INITIALLY <value> when issuing DDL
+          for this constraint.
+
+        :param using:
+          Optional string.  If set, emit USING <index_method> when issuing DDL
+          for this constraint. Defaults to 'gist'.
+          
+        :param where:
+          Optional string.  If set, emit WHERE <predicate> when issuing DDL
+          for this constraint.
+
+        """
+        ColumnCollectionConstraint.__init__(
+            self,
+            *[col for col, op in elements],
+            name=kw.get('name'),
+            deferrable=kw.get('deferrable'),
+            initially=kw.get('initially')
+            )
+        self.operators = {}
+        for col_or_string, op in elements:
+            name = getattr(col_or_string, 'name', col_or_string)
+            self.operators[name] = op
+        self.using = kw.get('using', 'gist')
+        where = kw.get('where')
+        if where:
+            self.where =  expression._literal_as_text(where)
+            
+    def copy(self, **kw):
+        elements = [(col, self.operators[col])
+                    for col in self.columns.keys()]
+        c = self.__class__(*elements,
+                            name=self.name,
+                            deferrable=self.deferrable,
+                            initially=self.initially)
+        c.dispatch._update(self.dispatch)
+        return c
+
index 70b683b088200c17bcb2539ee69ab8e9de80de3c..2203d93451d413a14ce344bf0909c64edd6e8303 100644 (file)
@@ -18,7 +18,8 @@ from sqlalchemy.orm import Session, mapper, aliased
 from sqlalchemy import exc, schema, types
 from sqlalchemy.dialects.postgresql import base as postgresql
 from sqlalchemy.dialects.postgresql import HSTORE, hstore, array, \
-            INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, TSTZRANGE
+            INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, TSTZRANGE, \
+            ExcludeConstraint
 import decimal
 from sqlalchemy import util
 from sqlalchemy.testing.util import round_decimal
@@ -183,6 +184,53 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
                             'USING hash (data)',
                             dialect=postgresql.dialect())
 
+    def test_exclude_constraint_min(self):
+        m = MetaData()
+        tbl = Table('testtbl', m, 
+                    Column('room', Integer, primary_key=True))
+        cons = ExcludeConstraint(('room', '='))
+        tbl.append_constraint(cons)
+        self.assert_compile(schema.AddConstraint(cons),
+                            'ALTER TABLE testtbl ADD EXCLUDE USING gist '
+                            '(room WITH =)',
+                            dialect=postgresql.dialect())
+
+    def test_exclude_constraint_full(self):
+        m = MetaData()
+        room = Column('room', Integer, primary_key=True)
+        tbl = Table('testtbl', m,
+                    room,
+                    Column('during', TSRANGE))
+        room = Column('room', Integer, primary_key=True)
+        cons = ExcludeConstraint((room, '='), ('during', '&&'),
+                                 name='my_name',
+                                 using='gist',
+                                 where="room > 100",
+                                 deferrable=True,
+                                 initially='immediate')
+        tbl.append_constraint(cons)
+        self.assert_compile(schema.AddConstraint(cons),
+                            'ALTER TABLE testtbl ADD CONSTRAINT my_name '
+                            'EXCLUDE USING gist '
+                            '(room WITH =, during WITH ''&&) WHERE '
+                            '(room > 100) DEFERRABLE INITIALLY immediate',
+                            dialect=postgresql.dialect())
+
+    def test_exclude_constraint_copy(self):
+        m = MetaData()
+        cons = ExcludeConstraint(('room', '='))
+        tbl = Table('testtbl', m, 
+              Column('room', Integer, primary_key=True),
+              cons)
+        # apparently you can't copy a ColumnCollectionConstraint until
+        # after it has been bound to a table...
+        cons_copy = cons.copy()
+        tbl.append_constraint(cons_copy)
+        self.assert_compile(schema.AddConstraint(cons_copy),
+                            'ALTER TABLE testtbl ADD EXCLUDE USING gist '
+                            '(room WITH =)',
+                            dialect=postgresql.dialect())
+
     def test_substring(self):
         self.assert_compile(func.substring('abc', 1, 2),
                             'SUBSTRING(%(substring_1)s FROM %(substring_2)s '