]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed bug in unit of work whereby a joined-inheritance
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 1 Apr 2013 17:54:09 +0000 (13:54 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 1 Apr 2013 17:54:09 +0000 (13:54 -0400)
    subclass could insert the row for the "sub" table
    before the parent table, if the two tables had no
    ForeignKey constraints set up between them. [ticket:2689]
- fixed glitch in assertsql.py regarding CompiledSQL + AllOf +
multiple params

doc/build/changelog/changelog_07.rst
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/sql/util.py
test/lib/assertsql.py
test/orm/inheritance/test_basic.py

index 0e7809d183c37345b2917d0b5b0992081018564c..f1b30a55a79dce30a99f545b1d0c029b7329c0c1 100644 (file)
@@ -6,6 +6,15 @@
 .. changelog::
     :version: 0.7.11
 
+    .. change::
+      :tags: bug, orm
+      :tickets: 2689
+
+    Fixed bug in unit of work whereby a joined-inheritance
+    subclass could insert the row for the "sub" table
+    before the parent table, if the two tables had no
+    ForeignKey constraints set up between them.
+
     .. change::
         :tags: feature, postgresql
         :tickets: 2676
index 36771de234076d1f72707b8a5b4bf2fd55a68168..fca1d438793418afb017b162464b7020b7e01e53 100644 (file)
@@ -1853,11 +1853,24 @@ class Mapper(object):
     @_memoized_configured_property
     def _sorted_tables(self):
         table_to_mapper = {}
+        table_to_mapper_08 = {}
         for mapper in self.base_mapper.self_and_descendants:
             for t in mapper.tables:
                 table_to_mapper[t] = mapper
-
-        sorted_ = sqlutil.sort_tables(table_to_mapper.iterkeys())
+                # emulate 0.8's approach to fix #2689
+                table_to_mapper_08.setdefault(t, mapper)
+
+        extra_dependencies = []
+        for table, mapper in table_to_mapper_08.items():
+            super_ = mapper.inherits
+            if super_:
+                extra_dependencies.extend([
+                    (super_table, table)
+                    for super_table in super_.tables
+                    ])
+
+        sorted_ = sqlutil.sort_tables(table_to_mapper.iterkeys(),
+                            extra_dependencies=extra_dependencies)
         ret = util.OrderedDict()
         for t in sorted_:
             ret[t] = table_to_mapper[t]
index 0a00674cc83048a25582bf8c58b0545f8f8d3c12..bf4235ad7f04d3daeaeb6033ea7ee986b678f34e 100644 (file)
@@ -12,11 +12,14 @@ from collections import deque
 
 """Utility functions that build upon SQL and Schema constructs."""
 
-def sort_tables(tables):
+def sort_tables(tables, extra_dependencies=None):
     """sort a collection of Table objects in order of their foreign-key dependency."""
 
     tables = list(tables)
     tuples = []
+    if extra_dependencies:
+        tuples.extend(extra_dependencies)
+
     def visit_foreign_key(fkey):
         if fkey.use_alter:
             return
@@ -28,8 +31,8 @@ def sort_tables(tables):
 
     for table in tables:
         visitors.traverse(table,
-                            {'schema_visitor':True},
-                            {'foreign_key':visit_foreign_key})
+                            {'schema_visitor': True},
+                            {'foreign_key': visit_foreign_key})
 
         tuples.extend(
             [parent, table] for parent in table._extra_dependencies
index 897f4b3b1d2df6a07bacce27f81e5f9ef51df150..8b1f2473c3d01d85991c229bfcb465efa96881b4 100644 (file)
@@ -167,6 +167,8 @@ class CompiledSQL(SQLMatchRule):
                 params = self.params
             if not isinstance(params, list):
                 params = [params]
+            else:
+                params = list(params)
             all_params = list(params)
             all_received = list(_received_parameters)
             while params:
index a8cfe5e9a6d12407aa71041db76fd9b6c2cfbda5..76168e71512c51b7a4706d72be82f4f3878bf765 100644 (file)
@@ -915,6 +915,73 @@ class FlushTest(fixtures.MappedTest):
         sess.flush()
         assert user_roles.count().scalar() == 1
 
+class JoinedNoFKSortingTest(fixtures.MappedTest):
+    @classmethod
+    def define_tables(cls, metadata):
+        Table("a", metadata,
+                Column('id', Integer, primary_key=True,
+                    test_needs_autoincrement=True)
+            )
+        Table("b", metadata,
+                Column('id', Integer, primary_key=True)
+            )
+        Table("c", metadata,
+                Column('id', Integer, primary_key=True)
+            )
+
+    @classmethod
+    def setup_classes(cls):
+        class A(cls.Basic):
+            pass
+        class B(A):
+            pass
+        class C(A):
+            pass
+
+    @classmethod
+    def setup_mappers(cls):
+        A, B, C = cls.classes.A, cls.classes.B, cls.classes.C
+        mapper(A, cls.tables.a)
+        mapper(B, cls.tables.b, inherits=A,
+                    inherit_condition=cls.tables.a.c.id == cls.tables.b.c.id)
+        mapper(C, cls.tables.c, inherits=A,
+                    inherit_condition=cls.tables.a.c.id == cls.tables.c.c.id)
+
+    def test_ordering(self):
+        B, C = self.classes.B, self.classes.C
+        sess = Session()
+        sess.add_all([B(), C(), B(), C()])
+        self.assert_sql_execution(
+                testing.db,
+                sess.flush,
+                CompiledSQL(
+                    "INSERT INTO a () VALUES ()",
+                    {}
+                ),
+                CompiledSQL(
+                    "INSERT INTO a () VALUES ()",
+                    {}
+                ),
+                CompiledSQL(
+                    "INSERT INTO a () VALUES ()",
+                    {}
+                ),
+                CompiledSQL(
+                    "INSERT INTO a () VALUES ()",
+                    {}
+                ),
+                AllOf(
+                    CompiledSQL(
+                        "INSERT INTO b (id) VALUES (:id)",
+                        [{"id": 1}, {"id": 3}]
+                    ),
+                    CompiledSQL(
+                        "INSERT INTO c (id) VALUES (:id)",
+                        [{"id": 2}, {"id": 4}]
+                    )
+                )
+        )
+
 class VersioningTest(fixtures.MappedTest):
     @classmethod
     def define_tables(cls, metadata):