]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- expand the check to determine if a selectable column is embedded
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 29 Feb 2012 22:47:59 +0000 (17:47 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 29 Feb 2012 22:47:59 +0000 (17:47 -0500)
in the corresponding selectable to take into account clones
of the target column.  fixes [ticket:2419]
- have _make_proxy() copy out the _is_clone_of attribute on the
new column so that even more corresponding_column() checks
work as expected for cloned elements.
- add a new test fixture so that mapped tests can be specified
using declarative.

CHANGES
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/expression.py
test/lib/fixtures.py
test/orm/inheritance/test_assorted_poly.py
test/sql/test_generative.py

diff --git a/CHANGES b/CHANGES
index 867ed011860e062b9f3e9c2edaa9bc56d9741f1e..418fc6b23c156a49af6d1dc292de25e5b086d331 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -17,6 +17,12 @@ CHANGES
     @collection.internally_instrumented.  
     [ticket:2406]
 
+  - [bug] Fixed bug whereby SQL adaption mechanics
+    would fail in a very nested scenario involving
+    joined-inheritance, joinedload(), limit(), and a
+    derived function in the columns clause.  
+    [ticket:2419]
+
   - [feature] Added the ability to query for
     Table-bound column names when using 
     query(sometable).filter_by(colname=value).  
index 3ae0053c3344eee2509e800b455a9ac36d64e7e6..154e18e5f464913eb62070d15bd639811e7bf0e4 100644 (file)
@@ -1122,6 +1122,8 @@ class Column(SchemaItem, expression.ColumnClause):
 
         c.table = selectable
         selectable._columns.add(c)
+        if selectable._is_clone_of is not None:
+            c._is_clone_of = selectable._is_clone_of.columns[c.name]
         if self.primary_key:
             selectable.primary_key.add(c)
         c.dispatch.after_parent_attach(c, selectable)
index b11e5ad429845bf22df7d2d2d104054ecdfa9e0d..4b61e6dc334a441bea0928273d5505d7e46dec2e 100644 (file)
@@ -1508,6 +1508,7 @@ class ClauseElement(Visitable):
     supports_execution = False
     _from_objects = []
     bind = None
+    _is_clone_of = None
 
     def _clone(self):
         """Create a shallow copy of this ClauseElement.
@@ -1556,7 +1557,7 @@ class ClauseElement(Visitable):
         f = self
         while f is not None:
             s.add(f)
-            f = getattr(f, '_is_clone_of', None)
+            f = f._is_clone_of
         return s
 
     def __getstate__(self):
@@ -2158,6 +2159,9 @@ class ColumnElement(ClauseElement, _CompareMixin):
                             type_=getattr(self,
                           'type', None))
         co.proxies = [self]
+        if selectable._is_clone_of is not None:
+            co._is_clone_of = \
+                selectable._is_clone_of.columns[key]
         selectable._columns[key] = co
         return co
 
@@ -2466,6 +2470,13 @@ class FromClause(Selectable):
 
         """
 
+        def embedded(expanded_proxy_set, target_set):
+            for t in target_set.difference(expanded_proxy_set):
+                if not set(_expand_cloned([t])
+                            ).intersection(expanded_proxy_set):
+                    return False
+            return True
+
         # dont dig around if the column is locally present
         if self.c.contains_column(column):
             return column
@@ -2473,10 +2484,10 @@ class FromClause(Selectable):
         target_set = column.proxy_set
         cols = self.c
         for c in cols:
-            i = target_set.intersection(itertools.chain(*[p._cloned_set
-                    for p in c.proxy_set]))
+            expanded_proxy_set = set(_expand_cloned(c.proxy_set))
+            i = target_set.intersection(expanded_proxy_set)
             if i and (not require_embedded
-                      or c.proxy_set.issuperset(target_set)):
+                      or embedded(expanded_proxy_set, target_set)):
                 if col is None:
 
                     # no corresponding column yet, pick this one.
@@ -4073,6 +4084,9 @@ class ColumnClause(_Immutable, ColumnElement):
                     is_literal=is_literal
                 )
         c.proxies = [self]
+        if selectable._is_clone_of is not None:
+            c._is_clone_of = \
+                selectable._is_clone_of.columns[c.name]
 
         if attach:
             selectable._columns[c.name] = c
index 6714107bdc4d759a9ac9b704134cb2605daf91c6..03116fbc113356fbd16c051a8652c49600aac2ca 100644 (file)
@@ -4,6 +4,7 @@ from test.lib.engines import drop_all_tables
 import sys
 import sqlalchemy as sa
 from test.lib.entities import BasicEntity, ComparableEntity
+from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
 
 class TestBase(object):
     # A sequence of database names to always run, regardless of the
@@ -228,8 +229,7 @@ class MappedTest(_ORMTest, TablesTest, testing.AssertsExecutionResults):
 
     @classmethod
     def teardown_class(cls):
-        cls.classes.clear()
-        _ORMTest.teardown_class()
+        cls._teardown_once_class()
         cls._teardown_once_metadata_bind()
 
     def setup(self):
@@ -242,6 +242,12 @@ class MappedTest(_ORMTest, TablesTest, testing.AssertsExecutionResults):
         self._teardown_each_mappers()
         self._teardown_each_tables()
 
+    @classmethod
+    def _teardown_once_class(cls):
+        cls.classes.clear()
+        _ORMTest.teardown_class()
+
+
     @classmethod
     def _setup_once_classes(cls):
         if cls.run_setup_classes == 'once':
@@ -269,6 +275,7 @@ class MappedTest(_ORMTest, TablesTest, testing.AssertsExecutionResults):
                 cls_registry[classname] = cls
                 return type.__init__(cls, classname, bases, dict_)
 
+
         class _Base(object):
             __metaclass__ = FindFixture
         class Basic(BasicEntity, _Base):
@@ -294,3 +301,35 @@ class MappedTest(_ORMTest, TablesTest, testing.AssertsExecutionResults):
     def setup_mappers(cls):
         pass
 
+class DeclarativeMappedTest(MappedTest):
+    declarative_meta = None
+
+    @classmethod
+    def setup_class(cls):
+        if cls.declarative_meta is None:
+            cls.declarative_meta = sa.MetaData()
+
+        super(DeclarativeMappedTest, cls).setup_class()
+
+    @classmethod
+    def _teardown_once_class(cls):
+        if cls.declarative_meta.tables:
+            cls.declarative_meta.drop_all(testing.db)
+        super(DeclarativeMappedTest, cls)._teardown_once_class()
+
+    @classmethod
+    def _with_register_classes(cls, fn):
+        cls_registry = cls.classes
+        class FindFixtureDeclarative(DeclarativeMeta):
+            def __init__(cls, classname, bases, dict_):
+                cls_registry[classname] = cls
+                return DeclarativeMeta.__init__(
+                        cls, classname, bases, dict_)
+        _DeclBase = declarative_base(metadata=cls.declarative_meta, 
+                            metaclass=FindFixtureDeclarative)
+        class DeclarativeBasic(BasicEntity):
+            pass
+        cls.DeclarativeBasic = _DeclBase
+        fn()
+        if cls.declarative_meta.tables:
+            cls.declarative_meta.create_all(testing.db)
index 0a10071ea51e223fdc3d3c95978146ae10a86f1c..4fa1034adffe4416a4f7bc52807e47f5168e41f7 100644 (file)
@@ -1,5 +1,5 @@
-"""Very old inheritance-related tests.
-
+"""Miscellaneous inheritance-related tests, many very old.
+These are generally tests derived from specific user issues.
 
 """
 
@@ -1451,4 +1451,65 @@ class JoinedInhAdjacencyTest(fixtures.MappedTest):
             }
         )
         assert Dude.supervisor.property.direction is MANYTOONE
-        self._dude_roundtrip()
\ No newline at end of file
+        self._dude_roundtrip()
+
+
+class Ticket2419Test(fixtures.DeclarativeMappedTest):
+    """Test [ticket:2419]'s test case."""
+
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+        class A(Base):
+            __tablename__ = "a"
+
+            id = Column(Integer, primary_key=True)
+
+        class B(Base):
+            __tablename__ = "b"
+
+            id = Column(Integer, primary_key=True)
+            ds = relationship("D")
+            es = relationship("E")
+
+        class C(A):
+            __tablename__ = "c"
+
+            id = Column(Integer, ForeignKey('a.id'), primary_key=True)
+            b_id = Column(Integer, ForeignKey('b.id'))
+            b = relationship("B", primaryjoin=b_id==B.id)
+
+        class D(Base):
+            __tablename__ = "d"
+
+            id = Column(Integer, primary_key=True)
+            b_id = Column(Integer, ForeignKey('b.id'))
+
+        class E(Base):
+            __tablename__ = 'e'
+            id = Column(Integer, primary_key=True)
+            b_id = Column(Integer, ForeignKey('b.id'))
+
+    def test_join_w_eager_w_any(self):
+        A, B, C, D, E = self.classes.A, self.classes.B, \
+                        self.classes.C, self.classes.D, \
+                        self.classes.E
+        s = Session(testing.db)
+
+        b = B(ds=[D()])
+        s.add_all([
+            C(
+                b=b
+            )
+
+        ])
+
+        s.commit()
+
+        q = s.query(B, B.ds.any(D.id==1)).options(joinedload_all("es"))
+        q = q.join(C, C.b_id==B.id)
+        q = q.limit(5)
+        eq_(
+            q.all(),
+            [(b, True)]
+        )
index f9333dbf5329b0092474c22716944721636b4f65..98e783ede9e6a23275d0650e87d7fd0182511c25 100644 (file)
@@ -5,7 +5,7 @@ from test.lib import *
 from sqlalchemy.sql.visitors import *
 from sqlalchemy import util, exc
 from sqlalchemy.sql import util as sql_util
-from test.lib.testing import eq_, assert_raises
+from test.lib.testing import eq_, ne_, assert_raises
 
 class TraversalTest(fixtures.TestBase, AssertsExecutionResults):
     """test ClauseVisitor's traversal, particularly its 
@@ -173,7 +173,7 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
 
     @classmethod
     def setup_class(cls):
-        global t1, t2
+        global t1, t2, t3
         t1 = table("table1",
             column("col1"),
             column("col2"),
@@ -184,6 +184,10 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
             column("col2"),
             column("col3"),
             )
+        t3 = Table('table3', MetaData(), 
+            Column('col1', Integer),
+            Column('col2', Integer)
+        )
 
     def test_binary(self):
         clause = t1.c.col2 == t2.c.col2
@@ -242,6 +246,80 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
             str(f)
         )
 
+    def test_aliased_cloned_column_adapt_inner(self):
+        clause = select([t1.c.col1, func.foo(t1.c.col2).label('foo')])
+
+        aliased1 = select([clause.c.col1, clause.c.foo])
+        aliased2 = clause
+        aliased2.c.col1, aliased2.c.foo
+        aliased3 = cloned_traverse(aliased2, {}, {})
+
+        # fixed by [ticket:2419].   the inside columns
+        # on aliased3 have _is_clone_of pointers to those of
+        # aliased2.  corresponding_column checks these 
+        # now.
+        adapter = sql_util.ColumnAdapter(aliased1)
+        f1 = select([
+            adapter.columns[c]
+            for c in aliased2._raw_columns
+        ])
+        f2 = select([
+            adapter.columns[c]
+            for c in aliased3._raw_columns
+        ])
+        eq_(
+            str(f1), str(f2)
+        )
+
+    def test_aliased_cloned_column_adapt_exported(self):
+        clause = select([t1.c.col1, func.foo(t1.c.col2).label('foo')])
+
+        aliased1 = select([clause.c.col1, clause.c.foo])
+        aliased2 = clause
+        aliased2.c.col1, aliased2.c.foo
+        aliased3 = cloned_traverse(aliased2, {}, {})
+
+        # also fixed by [ticket:2419].  When we look at the
+        # *outside* columns of aliased3, they previously did not 
+        # have an _is_clone_of pointer.   But we now modified _make_proxy
+        # to assign this.
+        adapter = sql_util.ColumnAdapter(aliased1)
+        f1 = select([
+            adapter.columns[c]
+            for c in aliased2.c
+        ])
+        f2 = select([
+            adapter.columns[c]
+            for c in aliased3.c
+        ])
+        eq_(
+            str(f1), str(f2)
+        )
+
+    def test_aliased_cloned_schema_column_adapt_exported(self):
+        clause = select([t3.c.col1, func.foo(t3.c.col2).label('foo')])
+
+        aliased1 = select([clause.c.col1, clause.c.foo])
+        aliased2 = clause
+        aliased2.c.col1, aliased2.c.foo
+        aliased3 = cloned_traverse(aliased2, {}, {})
+
+        # also fixed by [ticket:2419].  When we look at the
+        # *outside* columns of aliased3, they previously did not 
+        # have an _is_clone_of pointer.   But we now modified _make_proxy
+        # to assign this.
+        adapter = sql_util.ColumnAdapter(aliased1)
+        f1 = select([
+            adapter.columns[c]
+            for c in aliased2.c
+        ])
+        f2 = select([
+            adapter.columns[c]
+            for c in aliased3.c
+        ])
+        eq_(
+            str(f1), str(f2)
+        )
 
     def test_text(self):
         clause = text(