]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- establish an "insert" option for events to control ordering if needed (not needed...
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 25 Feb 2011 18:20:43 +0000 (13:20 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 25 Feb 2011 18:20:43 +0000 (13:20 -0500)
- render foreign key constraints in the order in which they were cerated

lib/sqlalchemy/event.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/compiler.py
test/base/test_events.py
test/sql/test_constraints.py

index 7c2b49ce8a0001c9156cedf5f80d08984a27f233..b2e5cd00f256306cb2f2889e5024186b71cfb139 100644 (file)
@@ -143,8 +143,11 @@ class Events(object):
             return None
 
     @classmethod
-    def _listen(cls, target, identifier, fn, propagate=False):
-        getattr(target.dispatch, identifier).append(fn, target, propagate)
+    def _listen(cls, target, identifier, fn, propagate=False, insert=False):
+        if insert:
+            getattr(target.dispatch, identifier).insert(fn, target, propagate)
+        else:
+            getattr(target.dispatch, identifier).append(fn, target, propagate)
 
     @classmethod
     def _remove(cls, target, identifier, fn):
@@ -164,6 +167,16 @@ class _DispatchDescriptor(object):
         self.__doc__ = fn.__doc__
         self._clslevel = util.defaultdict(list)
 
+    def insert(self, obj, target, propagate):
+        assert isinstance(target, type), \
+                "Class-level Event targets must be classes."
+
+        stack = [target]
+        while stack:
+            cls = stack.pop(0)
+            stack.extend(cls.__subclasses__())
+            self._clslevel[cls].insert(0, obj)
+
     def append(self, obj, target, propagate):
         assert isinstance(target, type), \
                 "Class-level Event targets must be classes."
@@ -260,6 +273,12 @@ class _ListenerCollection(object):
                                 and not only_propagate or l in self.propagate
                                 ])
 
+    def insert(self, obj, target, propagate):
+        if obj not in self.listeners:
+            self.listeners.insert(0, obj)
+            if propagate:
+                self.propagate.add(obj)
+
     def append(self, obj, target, propagate):
         if obj not in self.listeners:
             self.listeners.append(obj)
index e6b970291cbb5ab2062b8c691b30d625a8add51a..80d6018b843e4accba1aa17859f5d6f944c9d3c5 100644 (file)
@@ -265,6 +265,12 @@ class Table(SchemaItem, expression.TableClause):
         # allow user-overrides
         self._init_items(*args)
 
+    @property
+    def _sorted_constraints(self):
+        """Return the set of constraints as a list, sorted by creation order."""
+
+        return sorted(self.constraints, key=lambda c:c._creation_order)
+
     def _init_existing(self, *args, **kwargs):
         autoload = kwargs.pop('autoload', False)
         autoload_with = kwargs.pop('autoload_with', None)
@@ -1595,6 +1601,7 @@ class Constraint(SchemaItem):
         self.deferrable = deferrable
         self.initially = initially
         self._create_rule = _create_rule
+        util.set_creation_order(self)
 
     @property
     def table(self):
index 781072dd03b953b41fe8823832049e06fdcafe8e..d6a020bdccd19a18d8be1f2513a8e101445a1e52 100644 (file)
@@ -1272,7 +1272,7 @@ class DDLCompiler(engine.Compiled):
         if table.primary_key:
             constraints.append(table.primary_key)
 
-        constraints.extend([c for c in table.constraints 
+        constraints.extend([c for c in table._sorted_constraints 
                                 if c is not table.primary_key])
 
         return ", \n\t".join(p for p in
index e894a1f74fcbb7e6ad1ac5c9c1d4f46bdcb6299d..f699a66f202beaf9c0ddd9cc88ed6bd1f253c298 100644 (file)
@@ -70,6 +70,25 @@ class TestEvents(TestBase):
         eq_(len(Target().dispatch.event_one), 2)
         eq_(len(t1.dispatch.event_one), 3)
 
+    def test_append_vs_insert(self):
+        def listen_one(x, y):
+            pass
+
+        def listen_two(x, y):
+            pass
+
+        def listen_three(x, y):
+            pass
+
+        event.listen(Target, "event_one", listen_one)
+        event.listen(Target, "event_one", listen_two)
+        event.listen(Target, "event_one", listen_three, insert=True)
+
+        eq_(
+            list(Target().dispatch.event_one),
+            [listen_three, listen_one, listen_two]
+        )
+
 class TestAcceptTargets(TestBase):
     """Test default target acceptance."""
 
index 1c13a0ec7adc992f389ace7e4372c81d4e8fae94..c5433fa9c2e2a082df94fe0acc5f2576f885b503 100644 (file)
@@ -322,6 +322,28 @@ class ConstraintCompilationTest(TestBase, AssertsCompiledSQL):
         factory = lambda **kw: CheckConstraint('a < b', **kw)
         self._test_deferrable(factory)
 
+    def test_multiple(self):
+        m = MetaData()
+        foo = Table("foo", m, 
+            Column('id', Integer, primary_key=True),
+            Column('bar', Integer, primary_key=True)
+        )
+        tb = Table("some_table", m,
+        Column('id', Integer, primary_key=True),
+        Column('foo_id', Integer, ForeignKey('foo.id')),
+        Column('foo_bar', Integer, ForeignKey('foo.bar')),
+        )
+        self.assert_compile(
+            schema.CreateTable(tb),
+            "CREATE TABLE some_table ("
+                "id INTEGER NOT NULL, "
+                "foo_id INTEGER, "
+                "foo_bar INTEGER, "
+                "PRIMARY KEY (id), "
+                "FOREIGN KEY(foo_id) REFERENCES foo (id), "
+                "FOREIGN KEY(foo_bar) REFERENCES foo (bar))"
+        )
+
     def test_deferrable_column_check(self):
         t = Table('tbl', MetaData(),
                   Column('a', Integer),