From: Mike Bayer Date: Fri, 25 Feb 2011 18:20:43 +0000 (-0500) Subject: - establish an "insert" option for events to control ordering if needed (not needed... X-Git-Tag: rel_0_7b3~45 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a7f766d7c7fd6c53eb0019e32569e915b3f31eb4;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - establish an "insert" option for events to control ordering if needed (not needed yet tho) - render foreign key constraints in the order in which they were cerated --- diff --git a/lib/sqlalchemy/event.py b/lib/sqlalchemy/event.py index 7c2b49ce8a..b2e5cd00f2 100644 --- a/lib/sqlalchemy/event.py +++ b/lib/sqlalchemy/event.py @@ -143,8 +143,11 @@ class Events(object): return None @classmethod - def _listen(cls, target, identifier, fn, propagate=False): - getattr(target.dispatch, identifier).append(fn, target, propagate) + def _listen(cls, target, identifier, fn, propagate=False, insert=False): + if insert: + getattr(target.dispatch, identifier).insert(fn, target, propagate) + else: + getattr(target.dispatch, identifier).append(fn, target, propagate) @classmethod def _remove(cls, target, identifier, fn): @@ -164,6 +167,16 @@ class _DispatchDescriptor(object): self.__doc__ = fn.__doc__ self._clslevel = util.defaultdict(list) + def insert(self, obj, target, propagate): + assert isinstance(target, type), \ + "Class-level Event targets must be classes." + + stack = [target] + while stack: + cls = stack.pop(0) + stack.extend(cls.__subclasses__()) + self._clslevel[cls].insert(0, obj) + def append(self, obj, target, propagate): assert isinstance(target, type), \ "Class-level Event targets must be classes." @@ -260,6 +273,12 @@ class _ListenerCollection(object): and not only_propagate or l in self.propagate ]) + def insert(self, obj, target, propagate): + if obj not in self.listeners: + self.listeners.insert(0, obj) + if propagate: + self.propagate.add(obj) + def append(self, obj, target, propagate): if obj not in self.listeners: self.listeners.append(obj) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index e6b970291c..80d6018b84 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -265,6 +265,12 @@ class Table(SchemaItem, expression.TableClause): # allow user-overrides self._init_items(*args) + @property + def _sorted_constraints(self): + """Return the set of constraints as a list, sorted by creation order.""" + + return sorted(self.constraints, key=lambda c:c._creation_order) + def _init_existing(self, *args, **kwargs): autoload = kwargs.pop('autoload', False) autoload_with = kwargs.pop('autoload_with', None) @@ -1595,6 +1601,7 @@ class Constraint(SchemaItem): self.deferrable = deferrable self.initially = initially self._create_rule = _create_rule + util.set_creation_order(self) @property def table(self): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 781072dd03..d6a020bdcc 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1272,7 +1272,7 @@ class DDLCompiler(engine.Compiled): if table.primary_key: constraints.append(table.primary_key) - constraints.extend([c for c in table.constraints + constraints.extend([c for c in table._sorted_constraints if c is not table.primary_key]) return ", \n\t".join(p for p in diff --git a/test/base/test_events.py b/test/base/test_events.py index e894a1f74f..f699a66f20 100644 --- a/test/base/test_events.py +++ b/test/base/test_events.py @@ -70,6 +70,25 @@ class TestEvents(TestBase): eq_(len(Target().dispatch.event_one), 2) eq_(len(t1.dispatch.event_one), 3) + def test_append_vs_insert(self): + def listen_one(x, y): + pass + + def listen_two(x, y): + pass + + def listen_three(x, y): + pass + + event.listen(Target, "event_one", listen_one) + event.listen(Target, "event_one", listen_two) + event.listen(Target, "event_one", listen_three, insert=True) + + eq_( + list(Target().dispatch.event_one), + [listen_three, listen_one, listen_two] + ) + class TestAcceptTargets(TestBase): """Test default target acceptance.""" diff --git a/test/sql/test_constraints.py b/test/sql/test_constraints.py index 1c13a0ec7a..c5433fa9c2 100644 --- a/test/sql/test_constraints.py +++ b/test/sql/test_constraints.py @@ -322,6 +322,28 @@ class ConstraintCompilationTest(TestBase, AssertsCompiledSQL): factory = lambda **kw: CheckConstraint('a < b', **kw) self._test_deferrable(factory) + def test_multiple(self): + m = MetaData() + foo = Table("foo", m, + Column('id', Integer, primary_key=True), + Column('bar', Integer, primary_key=True) + ) + tb = Table("some_table", m, + Column('id', Integer, primary_key=True), + Column('foo_id', Integer, ForeignKey('foo.id')), + Column('foo_bar', Integer, ForeignKey('foo.bar')), + ) + self.assert_compile( + schema.CreateTable(tb), + "CREATE TABLE some_table (" + "id INTEGER NOT NULL, " + "foo_id INTEGER, " + "foo_bar INTEGER, " + "PRIMARY KEY (id), " + "FOREIGN KEY(foo_id) REFERENCES foo (id), " + "FOREIGN KEY(foo_bar) REFERENCES foo (bar))" + ) + def test_deferrable_column_check(self): t = Table('tbl', MetaData(), Column('a', Integer),