]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Added deferrability support to constraints
authorJason Kirtland <jek@discorporate.us>
Fri, 8 Feb 2008 20:50:33 +0000 (20:50 +0000)
committerJason Kirtland <jek@discorporate.us>
Fri, 8 Feb 2008 20:50:33 +0000 (20:50 +0000)
CHANGES
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/compiler.py
test/orm/unitofwork.py
test/sql/constraints.py

diff --git a/CHANGES b/CHANGES
index d699e3384cc7f1f12a836ba2dc69fb9703127670..50df3a03031d841f77bc1c3c52b9f7ffb39ff2d0 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -29,6 +29,8 @@ CHANGES
     - cast() accepts text('something') and other non-literal
       operands properly [ticket:962]
 
+    - Deferrable constraints can now be defined.
+
     - added "autocommit=True" kwarg to select() and text(), as
       well as generative autocommit() method on select(); for
       statements which modify the database through some
index 64e9d203d79e1df2e2913339109e113ab22b2089..83f282b2464633155230b7dec10f09d1a788aef9 100644 (file)
@@ -608,7 +608,7 @@ class ForeignKey(SchemaItem):
     constraint definition.
     """
 
-    def __init__(self, column, constraint=None, use_alter=False, name=None, onupdate=None, ondelete=None):
+    def __init__(self, column, constraint=None, use_alter=False, name=None, onupdate=None, ondelete=None, deferrable=None, initially=None):
         """Construct a new ``ForeignKey`` object.
 
         column
@@ -629,6 +629,8 @@ class ForeignKey(SchemaItem):
         self.name = name
         self.onupdate = onupdate
         self.ondelete = ondelete
+        self.deferrable = deferrable
+        self.initially = initially
 
     def __repr__(self):
         return "ForeignKey(%s)" % repr(self._get_colspec())
@@ -714,7 +716,7 @@ class ForeignKey(SchemaItem):
                     self.parent.table.constraints.remove(fk.constraint)
 
         if self.constraint is None and isinstance(self.parent.table, Table):
-            self.constraint = ForeignKeyConstraint([],[], use_alter=self.use_alter, name=self.name, onupdate=self.onupdate, ondelete=self.ondelete)
+            self.constraint = ForeignKeyConstraint([],[], use_alter=self.use_alter, name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, deferrable=self.deferrable, initially=self.initially)
             self.parent.table.append_constraint(self.constraint)
             self.constraint._append_fk(self)
 
@@ -855,9 +857,11 @@ class Constraint(SchemaItem):
     list of underying columns.
     """
 
-    def __init__(self, name=None):
+    def __init__(self, name=None, deferrable=None, initially=None):
         self.name = name
         self.columns = expression.ColumnCollection()
+        self.deferrable = deferrable
+        self.initially = initially
 
     def __contains__(self, x):
         return self.columns.contains_column(x)
@@ -878,8 +882,8 @@ class Constraint(SchemaItem):
         raise NotImplementedError()
 
 class CheckConstraint(Constraint):
-    def __init__(self, sqltext, name=None):
-        super(CheckConstraint, self).__init__(name)
+    def __init__(self, sqltext, name=None, deferrable=None, initially=None):
+        super(CheckConstraint, self).__init__(name, deferrable, initially)
         self.sqltext = sqltext
 
     def __visit_name__(self):
@@ -899,8 +903,8 @@ class CheckConstraint(Constraint):
 class ForeignKeyConstraint(Constraint):
     """Table-level foreign key constraint, represents a collection of ``ForeignKey`` objects."""
 
-    def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None, use_alter=False):
-        super(ForeignKeyConstraint, self).__init__(name)
+    def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None, use_alter=False, deferrable=None, initially=None):
+        super(ForeignKeyConstraint, self).__init__(name, deferrable, initially)
         self.__colnames = columns
         self.__refcolnames = refcolumns
         self.elements = util.OrderedSet()
@@ -930,7 +934,15 @@ class ForeignKeyConstraint(Constraint):
 
 class PrimaryKeyConstraint(Constraint):
     def __init__(self, *columns, **kwargs):
-        super(PrimaryKeyConstraint, self).__init__(name=kwargs.pop('name', None))
+        constraint_args = dict(name=kwargs.pop('name', None),
+                               deferrable=kwargs.pop('deferrable', None),
+                               initially=kwargs.pop('initially', None))
+        if kwargs:
+            raise exceptions.ArgumentError(
+                'Unknown PrimaryKeyConstraint argument(s): %s' %
+                ', '.join([repr(x) for x in kwargs.keys()]))
+
+        super(PrimaryKeyConstraint, self).__init__(**constraint_args)
         self.__colnames = list(columns)
 
     def _set_parent(self, table):
@@ -959,7 +971,15 @@ class PrimaryKeyConstraint(Constraint):
 
 class UniqueConstraint(Constraint):
     def __init__(self, *columns, **kwargs):
-        super(UniqueConstraint, self).__init__(name=kwargs.pop('name', None))
+        constraint_args = dict(name=kwargs.pop('name', None),
+                               deferrable=kwargs.pop('deferrable', None),
+                               initially=kwargs.pop('initially', None))
+        if kwargs:
+            raise exceptions.ArgumentError(
+                'Unknown UniqueConstraint argument(s): %s' %
+                ', '.join([repr(x) for x in kwargs.keys()]))
+
+        super(UniqueConstraint, self).__init__(**constraint_args)
         self.__colnames = list(columns)
 
     def _set_parent(self, table):
index 43950a9a6ed396a2da7ca7a31c5bfcceff093ee0..02f6efce184cec2b6e0eebd64dabcaa1e824c173 100644 (file)
@@ -844,9 +844,11 @@ class SchemaGenerator(DDLBase):
             self.append("CONSTRAINT %s " %
                         self.preparer.format_constraint(constraint))
         self.append(" CHECK (%s)" % constraint.sqltext)
+        self.define_constraint_deferrability(constraint)
 
     def visit_column_check_constraint(self, constraint):
         self.append(" CHECK (%s)" % constraint.sqltext)
+        self.define_constraint_deferrability(constraint)
 
     def visit_primary_key_constraint(self, constraint):
         if len(constraint) == 0:
@@ -856,6 +858,7 @@ class SchemaGenerator(DDLBase):
             self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint))
         self.append("PRIMARY KEY ")
         self.append("(%s)" % ', '.join([self.preparer.quote(c, c.name) for c in constraint]))
+        self.define_constraint_deferrability(constraint)
 
     def visit_foreign_key_constraint(self, constraint):
         if constraint.use_alter and self.dialect.supports_alter:
@@ -883,6 +886,7 @@ class SchemaGenerator(DDLBase):
             self.append(" ON DELETE %s" % constraint.ondelete)
         if constraint.onupdate is not None:
             self.append(" ON UPDATE %s" % constraint.onupdate)
+        self.define_constraint_deferrability(constraint)
 
     def visit_unique_constraint(self, constraint):
         self.append(", \n\t")
@@ -890,6 +894,16 @@ class SchemaGenerator(DDLBase):
             self.append("CONSTRAINT %s " %
                         self.preparer.format_constraint(constraint))
         self.append(" UNIQUE (%s)" % (', '.join([self.preparer.quote(c, c.name) for c in constraint])))
+        self.define_constraint_deferrability(constraint)
+
+    def define_constraint_deferrability(self, constraint):
+        if constraint.deferrable is not None:
+            if constraint.deferrable:
+                self.append(" DEFERRABLE")
+            else:
+                self.append(" NOT DEFERRABLE")
+        if constraint.initially is not None:
+            self.append(" INITIALLY %s" % constraint.initially)
 
     def visit_column(self, column):
         pass
index 22c8bbe8bdc9de97b5c19f80ae7f562e4dd3e900..ee696cd9d47959d5a792249ecc3d3d798e78b7f2 100644 (file)
@@ -1969,9 +1969,12 @@ class RowSwitchTest(ORMTest):
         assert list(sess.execute(t2.select(), mapper=T1)) == [(1, 'some other t2', 2)]
 
 class TransactionTest(ORMTest):
-    """This is in fact a core test, but currently the only known way
-    to make COMMIT repeatably fail is on postgresql with deferrable FKs"""
-    __only_on__ = 'postgres'
+    __unsupported_on__ = ('mysql', 'mssql')
+
+    # sqlite doesn't have deferrable constraints, but it allows them to
+    # be specified.  it'll raise immediately post-INSERT, instead of at
+    # COMMIT. either way, this test should pass.
+
     def define_tables(self, metadata):
         global t1, T1, t2, T2
 
@@ -1979,17 +1982,24 @@ class TransactionTest(ORMTest):
 
         t1 = Table('t1', metadata,
             Column('id', Integer, primary_key=True))
-        
+
         t2 = Table('t2', metadata,
             Column('id', Integer, primary_key=True),
-            Column('t1_id', Integer))
-        deferred_constraint = DDL("ALTER TABLE t2 ADD CONSTRAINT t2_t1_id_fk FOREIGN KEY (t1_id) "\
-                                  "REFERENCES t1 (id) DEFERRABLE INITIALLY DEFERRED")
-        deferred_constraint.execute_at('after-create', t2)
-        
+            Column('t1_id', Integer,
+                   ForeignKey('t1.id', deferrable=True, initially='deferred')
+                   ))
+
+        # deferred_constraint = \
+        #   DDL("ALTER TABLE t2 ADD CONSTRAINT t2_t1_id_fk FOREIGN KEY (t1_id) "
+        #       "REFERENCES t1 (id) DEFERRABLE INITIALLY DEFERRED")
+        # deferred_constraint.execute_at('after-create', t2)
+        # t1.create()
+        # t2.create()
+        # t2.append_constraint(ForeignKeyConstraint(['t1_id'], ['t1.id']))
+
         class T1(fixtures.Base):
             pass
-        
+
         class T2(fixtures.Base):
             pass
 
@@ -1999,8 +2009,11 @@ class TransactionTest(ORMTest):
     def test_close_transaction_on_commit_fail(self):
         Session = sessionmaker(autoflush=False, transactional=False)
         sess = Session()
-        
+
+        # with a deferred constraint, this fails at COMMIT time instead
+        # of at INSERT time.
         sess.save(T2(t1_id=123))
+
         try:
             sess.flush()
             assert False
@@ -2008,5 +2021,9 @@ class TransactionTest(ORMTest):
             # Flush needs to rollback also when commit fails
             assert sess.transaction is None
 
+        # todo: on 8.3 at least, the failed commit seems to close the cursor?
+        # needs investigation.  leaving in the DDL above now to help verify
+        # that the new deferrable support on FK isn't involved in this issue.
+        t1.bind.engine.dispose()
 if __name__ == "__main__":
     testenv.main()
index 142f1ffba654a9d72f315e4acc3b922a8a277ba2..29fffa7512544f23e96f4599b4ecde0269438796 100644 (file)
@@ -2,6 +2,7 @@ import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy import exceptions
 from testlib import *
+from testlib import config, engines
 
 class ConstraintTest(AssertMixin):
 
@@ -201,5 +202,109 @@ class ConstraintTest(AssertMixin):
         ss = events.select().execute().fetchall()
 
 
+class ConstraintCompilationTest(AssertMixin):
+    class accum(object):
+        def __init__(self):
+            self.statements = []
+        def __call__(self, sql, *a, **kw):
+            self.statements.append(sql)
+        def __contains__(self, substring):
+            for s in self.statements:
+                if substring in s:
+                    return True
+            return False
+        def __str__(self):
+            return '\n'.join([repr(x) for x in self.statements])
+        def clear(self):
+            del self.statements[:]
+
+    def setUp(self):
+        self.sql = self.accum()
+        opts = config.db_opts.copy()
+        opts['strategy'] = 'mock'
+        opts['executor'] = self.sql
+        self.engine = engines.testing_engine(options=opts)
+
+
+    def _test_deferrable(self, constraint_factory):
+        meta = MetaData(self.engine)
+        t = Table('tbl', meta,
+                  Column('a', Integer),
+                  Column('b', Integer),
+                  constraint_factory(deferrable=True))
+        t.create()
+        assert 'DEFERRABLE' in self.sql, self.sql
+        assert 'NOT DEFERRABLE' not in self.sql, self.sql
+        self.sql.clear()
+        meta.clear()
+
+        t = Table('tbl', meta,
+                  Column('a', Integer),
+                  Column('b', Integer),
+                  constraint_factory(deferrable=False))
+        t.create()
+        assert 'NOT DEFERRABLE' in self.sql
+        self.sql.clear()
+        meta.clear()
+
+        t = Table('tbl', meta,
+                  Column('a', Integer),
+                  Column('b', Integer),
+                  constraint_factory(deferrable=True, initially='IMMEDIATE'))
+        t.create()
+        assert 'NOT DEFERRABLE' not in self.sql
+        assert 'INITIALLY IMMEDIATE' in self.sql
+        self.sql.clear()
+        meta.clear()
+
+        t = Table('tbl', meta,
+                  Column('a', Integer),
+                  Column('b', Integer),
+                  constraint_factory(deferrable=True, initially='DEFERRED'))
+        t.create()
+
+        assert 'NOT DEFERRABLE' not in self.sql
+        assert 'INITIALLY DEFERRED' in self.sql, self.sql
+
+    def test_deferrable_pk(self):
+        factory = lambda **kw: PrimaryKeyConstraint('a', **kw)
+        self._test_deferrable(factory)
+
+    def test_deferrable_table_fk(self):
+        factory = lambda **kw: ForeignKeyConstraint(['b'], ['tbl.a'], **kw)
+        self._test_deferrable(factory)
+
+    def test_deferrable_column_fk(self):
+        meta = MetaData(self.engine)
+        t = Table('tbl', meta,
+                  Column('a', Integer),
+                  Column('b', Integer,
+                         ForeignKey('tbl.a', deferrable=True,
+                                    initially='DEFERRED')))
+        t.create()
+        assert 'DEFERRABLE' in self.sql, self.sql
+        assert 'INITIALLY DEFERRED' in self.sql, self.sql
+
+    def test_deferrable_unique(self):
+        factory = lambda **kw: UniqueConstraint('b', **kw)
+        self._test_deferrable(factory)
+
+    def test_deferrable_table_check(self):
+        factory = lambda **kw: CheckConstraint('a < b', **kw)
+        self._test_deferrable(factory)
+
+    def test_deferrable_column_check(self):
+        meta = MetaData(self.engine)
+        t = Table('tbl', meta,
+                  Column('a', Integer),
+                  Column('b', Integer,
+                         CheckConstraint('a < b',
+                                         deferrable=True,
+                                         initially='DEFERRED')))
+        t.create()
+        assert 'DEFERRABLE' in self.sql, self.sql
+        assert 'INITIALLY DEFERRED' in self.sql, self.sql
+
+
 if __name__ == "__main__":
     testenv.main()