]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- The :class:`.DeferredReflection` class has been enhanced to provide
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 3 Dec 2013 18:46:41 +0000 (13:46 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 3 Dec 2013 18:46:41 +0000 (13:46 -0500)
automatic reflection support for the "secondary" table referred
to by a :func:`.relationship`.   "secondary", when specified
either as a string table name, or as a :class:`.Table` object with
only a name and :class:`.MetaData` object will also be included
in the reflection process when :meth:`.DeferredReflection.prepare`
is called. [ticket:2865]
- clsregistry._resolver() now uses a stateful _class_resolver()
class in order to handle the work of mapping strings to
objects.   This is to provide for simpler extensibility, namely
a ._resolvers collection of ad-hoc name resolution functions;
the DeferredReflection class adds its own resolver here in order
to handle relationship(secondary) names which generate new
Table objects.

doc/build/changelog/changelog_09.rst
lib/sqlalchemy/ext/declarative/api.py
lib/sqlalchemy/ext/declarative/clsregistry.py
test/ext/declarative/test_reflection.py

index 17116c2c4a653bc9337374c3034e3d6347a10d0c..367fa1df9b43299fbae21c93ae5ab2169b1282ab 100644 (file)
 .. changelog::
     :version: 0.9.0b2
 
+    .. change::
+        :tags: bug, orm, declarative
+        :tickets: 2865
+
+        The :class:`.DeferredReflection` class has been enhanced to provide
+        automatic reflection support for the "secondary" table referred
+        to by a :func:`.relationship`.   "secondary", when specified
+        either as a string table name, or as a :class:`.Table` object with
+        only a name and :class:`.MetaData` object will also be included
+        in the reflection process when :meth:`.DeferredReflection.prepare`
+        is called.
+
     .. change::
         :tags: feature, orm, backrefs
         :tickets: 1535
index 1cb653a233527162121cce8a17eebf5993d06987..64bf7fd9f3c8d22045a7508ccff41abd7bea19ce 100644 (file)
@@ -9,15 +9,17 @@
 from ...schema import Table, MetaData
 from ...orm import synonym as _orm_synonym, mapper,\
                                 comparable_property,\
-                                interfaces
+                                interfaces, properties
 from ...orm.util import polymorphic_union
 from ...orm.base import _mapper_or_none
+from ...util import compat
 from ... import exc
 import weakref
 
 from .base import _as_declarative, \
                 _declarative_constructor,\
                 _MapperConfig, _add_attribute
+from .clsregistry import _class_resolver
 
 
 def instrument_declarative(cls, registry, metadata):
@@ -465,11 +467,31 @@ class DeferredReflection(object):
     def prepare(cls, engine):
         """Reflect all :class:`.Table` objects for all current
         :class:`.DeferredReflection` subclasses"""
+
         to_map = [m for m in _MapperConfig.configs.values()
                     if issubclass(m.cls, cls)]
         for thingy in to_map:
             cls._sa_decl_prepare(thingy.local_table, engine)
             thingy.map()
+            mapper = thingy.cls.__mapper__
+            metadata = mapper.class_.metadata
+            for rel in mapper._props.values():
+                if isinstance(rel, properties.RelationshipProperty) and \
+                    rel.secondary is not None:
+                    if isinstance(rel.secondary, Table):
+                        cls._sa_decl_prepare(rel.secondary, engine)
+                    elif isinstance(rel.secondary, _class_resolver):
+                        rel.secondary._resolvers += (
+                            cls._sa_deferred_table_resolver(engine, metadata),
+                        )
+
+    @classmethod
+    def _sa_deferred_table_resolver(cls, engine, metadata):
+        def _resolve(key):
+            t1 = Table(key, metadata)
+            cls._sa_decl_prepare(t1, engine)
+            return t1
+        return _resolve
 
     @classmethod
     def _sa_decl_prepare(cls, local_table, engine):
index 8fef8f1bcb51461310f4c09647d4ff91f560163d..04567b32c4c855fcfeb6f291cb07c80e6a692fb3 100644 (file)
@@ -225,47 +225,62 @@ def _determine_container(key, value):
     return _GetColumns(value)
 
 
-def _resolver(cls, prop):
-    def resolve_arg(arg):
-        import sqlalchemy
-        from sqlalchemy.orm import foreign, remote
-
-        fallback = sqlalchemy.__dict__.copy()
-        fallback.update({'foreign': foreign, 'remote': remote})
-
-        def access_cls(key):
-            if key in cls._decl_class_registry:
-                return _determine_container(key, cls._decl_class_registry[key])
-            elif key in cls.metadata.tables:
-                return cls.metadata.tables[key]
-            elif key in cls.metadata._schemas:
-                return _GetTable(key, cls.metadata)
-            elif '_sa_module_registry' in cls._decl_class_registry and \
-                key in cls._decl_class_registry['_sa_module_registry']:
-                registry = cls._decl_class_registry['_sa_module_registry']
-                return registry.resolve_attr(key)
+class _class_resolver(object):
+    def __init__(self, cls, prop, fallback, arg):
+        self.cls = cls
+        self.prop = prop
+        self.arg = self._declarative_arg = arg
+        self.fallback = fallback
+        self._dict = util.PopulateDict(self._access_cls)
+        self._resolvers = ()
+
+    def _access_cls(self, key):
+        cls = self.cls
+        if key in cls._decl_class_registry:
+            return _determine_container(key, cls._decl_class_registry[key])
+        elif key in cls.metadata.tables:
+            return cls.metadata.tables[key]
+        elif key in cls.metadata._schemas:
+            return _GetTable(key, cls.metadata)
+        elif '_sa_module_registry' in cls._decl_class_registry and \
+            key in cls._decl_class_registry['_sa_module_registry']:
+            registry = cls._decl_class_registry['_sa_module_registry']
+            return registry.resolve_attr(key)
+        elif self._resolvers:
+            for resolv in self._resolvers:
+                value = resolv(key)
+                if value is not None:
+                    return value
+
+        return self.fallback[key]
+
+    def __call__(self):
+        try:
+            x = eval(self.arg, globals(), self._dict)
+
+            if isinstance(x, _GetColumns):
+                return x.cls
             else:
-                return fallback[key]
+                return x
+        except NameError as n:
+            raise exc.InvalidRequestError(
+                "When initializing mapper %s, expression %r failed to "
+                "locate a name (%r). If this is a class name, consider "
+                "adding this relationship() to the %r class after "
+                "both dependent classes have been defined." %
+                (self.prop.parent, self.arg, n.args[0], self.cls)
+            )
 
-        d = util.PopulateDict(access_cls)
 
-        def return_cls():
-            try:
-                x = eval(arg, globals(), d)
+def _resolver(cls, prop):
+    import sqlalchemy
+    from sqlalchemy.orm import foreign, remote
 
-                if isinstance(x, _GetColumns):
-                    return x.cls
-                else:
-                    return x
-            except NameError as n:
-                raise exc.InvalidRequestError(
-                    "When initializing mapper %s, expression %r failed to "
-                    "locate a name (%r). If this is a class name, consider "
-                    "adding this relationship() to the %r class after "
-                    "both dependent classes have been defined." %
-                    (prop.parent, arg, n.args[0], cls)
-                )
-        return return_cls
+    fallback = sqlalchemy.__dict__.copy()
+    fallback.update({'foreign': foreign, 'remote': remote})
+
+    def resolve_arg(arg):
+        return _class_resolver(cls, prop, fallback, arg)
     return resolve_arg
 
 
index 013439f9302d7bd67c453ae267c3ff268f369522..26496f1ada061a52c99d6e9e02d59feea46a58c4 100644 (file)
@@ -47,9 +47,8 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase):
             test_needs_fk=True,
             )
 
-    def test_basic(self):
-        meta = MetaData(testing.db)
 
+    def test_basic(self):
         class User(Base, fixtures.ComparableEntity):
 
             __tablename__ = 'users'
@@ -80,8 +79,6 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase):
         eq_(a1.user, User(name='u1'))
 
     def test_rekey(self):
-        meta = MetaData(testing.db)
-
         class User(Base, fixtures.ComparableEntity):
 
             __tablename__ = 'users'
@@ -114,8 +111,6 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase):
         assert_raises(TypeError, User, name='u3')
 
     def test_supplied_fk(self):
-        meta = MetaData(testing.db)
-
         class IMHandle(Base, fixtures.ComparableEntity):
 
             __tablename__ = 'imhandles'
@@ -151,7 +146,7 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase):
 
 class DeferredReflectBase(DeclarativeReflectionBase):
     def teardown(self):
-        super(DeferredReflectBase,self).teardown()
+        super(DeferredReflectBase, self).teardown()
         from sqlalchemy.ext.declarative.base import _MapperConfig
         _MapperConfig.configs.clear()
 
@@ -275,7 +270,7 @@ class DeferredReflectionTest(DeferredReflectBase):
             @decl.declared_attr
             def __mapper_args__(cls):
                 return {
-                    "order_by":cls.__table__.c.name
+                    "order_by": cls.__table__.c.name
                 }
 
         decl.DeferredReflection.prepare(testing.db)
@@ -297,6 +292,65 @@ class DeferredReflectionTest(DeferredReflectBase):
             ]
         )
 
+class DeferredSecondaryReflectionTest(DeferredReflectBase):
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('users', metadata,
+            Column('id', Integer,
+                primary_key=True, test_needs_autoincrement=True),
+              Column('name', String(50)), test_needs_fk=True)
+
+        Table('user_items', metadata,
+            Column('user_id', ForeignKey('users.id'), primary_key=True),
+            Column('item_id', ForeignKey('items.id'), primary_key=True),
+            test_needs_fk=True
+            )
+
+        Table('items', metadata,
+                Column('id', Integer, primary_key=True,
+                            test_needs_autoincrement=True),
+                Column('name', String(50)),
+                test_needs_fk=True
+            )
+
+    def _roundtrip(self):
+
+        User = Base._decl_class_registry['User']
+        Item = Base._decl_class_registry['Item']
+
+        u1 = User(name='u1', items=[Item(name='i1'), Item(name='i2')])
+
+        sess = Session()
+        sess.add(u1)
+        sess.commit()
+
+        eq_(sess.query(User).all(), [User(name='u1',
+            items=[Item(name='i1'), Item(name='i2')])])
+
+    def test_string_resolution(self):
+        class User(decl.DeferredReflection, fixtures.ComparableEntity, Base):
+            __tablename__ = 'users'
+
+            items = relationship("Item", secondary="user_items")
+
+        class Item(decl.DeferredReflection, fixtures.ComparableEntity, Base):
+            __tablename__ = 'items'
+
+        decl.DeferredReflection.prepare(testing.db)
+        self._roundtrip()
+
+    def test_table_resolution(self):
+        class User(decl.DeferredReflection, fixtures.ComparableEntity, Base):
+            __tablename__ = 'users'
+
+            items = relationship("Item", secondary=Table("user_items", Base.metadata))
+
+        class Item(decl.DeferredReflection, fixtures.ComparableEntity, Base):
+            __tablename__ = 'items'
+
+        decl.DeferredReflection.prepare(testing.db)
+        self._roundtrip()
+
 class DeferredInhReflectBase(DeferredReflectBase):
     def _roundtrip(self):
         Foo = Base._decl_class_registry['Foo']
@@ -338,11 +392,11 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
         class Foo(decl.DeferredReflection, fixtures.ComparableEntity,
                     Base):
             __tablename__ = 'foo'
-            __mapper_args__ = {"polymorphic_on":"type",
-                        "polymorphic_identity":"foo"}
+            __mapper_args__ = {"polymorphic_on": "type",
+                        "polymorphic_identity": "foo"}
 
         class Bar(Foo):
-            __mapper_args__ = {"polymorphic_identity":"bar"}
+            __mapper_args__ = {"polymorphic_identity": "bar"}
 
         decl.DeferredReflection.prepare(testing.db)
         self._roundtrip()
@@ -351,11 +405,11 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
         class Foo(decl.DeferredReflection, fixtures.ComparableEntity,
                     Base):
             __tablename__ = 'foo'
-            __mapper_args__ = {"polymorphic_on":"type",
-                        "polymorphic_identity":"foo"}
+            __mapper_args__ = {"polymorphic_on": "type",
+                        "polymorphic_identity": "foo"}
 
         class Bar(Foo):
-            __mapper_args__ = {"polymorphic_identity":"bar"}
+            __mapper_args__ = {"polymorphic_identity": "bar"}
             bar_data = Column(String(30))
 
         decl.DeferredReflection.prepare(testing.db)
@@ -365,12 +419,12 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
         class Foo(decl.DeferredReflection, fixtures.ComparableEntity,
                     Base):
             __tablename__ = 'foo'
-            __mapper_args__ = {"polymorphic_on":"type",
-                        "polymorphic_identity":"foo"}
+            __mapper_args__ = {"polymorphic_on": "type",
+                        "polymorphic_identity": "foo"}
             id = Column(Integer, primary_key=True)
 
         class Bar(Foo):
-            __mapper_args__ = {"polymorphic_identity":"bar"}
+            __mapper_args__ = {"polymorphic_identity": "bar"}
 
         decl.DeferredReflection.prepare(testing.db)
         self._roundtrip()
@@ -395,12 +449,12 @@ class DeferredJoinedInhReflectionTest(DeferredInhReflectBase):
         class Foo(decl.DeferredReflection, fixtures.ComparableEntity,
                     Base):
             __tablename__ = 'foo'
-            __mapper_args__ = {"polymorphic_on":"type",
-                        "polymorphic_identity":"foo"}
+            __mapper_args__ = {"polymorphic_on": "type",
+                        "polymorphic_identity": "foo"}
 
         class Bar(Foo):
             __tablename__ = 'bar'
-            __mapper_args__ = {"polymorphic_identity":"bar"}
+            __mapper_args__ = {"polymorphic_identity": "bar"}
 
         decl.DeferredReflection.prepare(testing.db)
         self._roundtrip()
@@ -409,12 +463,12 @@ class DeferredJoinedInhReflectionTest(DeferredInhReflectBase):
         class Foo(decl.DeferredReflection, fixtures.ComparableEntity,
                     Base):
             __tablename__ = 'foo'
-            __mapper_args__ = {"polymorphic_on":"type",
-                        "polymorphic_identity":"foo"}
+            __mapper_args__ = {"polymorphic_on": "type",
+                        "polymorphic_identity": "foo"}
 
         class Bar(Foo):
             __tablename__ = 'bar'
-            __mapper_args__ = {"polymorphic_identity":"bar"}
+            __mapper_args__ = {"polymorphic_identity": "bar"}
             bar_data = Column(String(30))
 
         decl.DeferredReflection.prepare(testing.db)
@@ -424,13 +478,13 @@ class DeferredJoinedInhReflectionTest(DeferredInhReflectBase):
         class Foo(decl.DeferredReflection, fixtures.ComparableEntity,
                     Base):
             __tablename__ = 'foo'
-            __mapper_args__ = {"polymorphic_on":"type",
-                        "polymorphic_identity":"foo"}
+            __mapper_args__ = {"polymorphic_on": "type",
+                        "polymorphic_identity": "foo"}
             id = Column(Integer, primary_key=True)
 
         class Bar(Foo):
             __tablename__ = 'bar'
-            __mapper_args__ = {"polymorphic_identity":"bar"}
+            __mapper_args__ = {"polymorphic_identity": "bar"}
 
         decl.DeferredReflection.prepare(testing.db)
         self._roundtrip()
@@ -439,12 +493,12 @@ class DeferredJoinedInhReflectionTest(DeferredInhReflectBase):
         class Foo(decl.DeferredReflection, fixtures.ComparableEntity,
                     Base):
             __tablename__ = 'foo'
-            __mapper_args__ = {"polymorphic_on":"type",
-                        "polymorphic_identity":"foo"}
+            __mapper_args__ = {"polymorphic_on": "type",
+                        "polymorphic_identity": "foo"}
 
         class Bar(Foo):
             __tablename__ = 'bar'
-            __mapper_args__ = {"polymorphic_identity":"bar"}
+            __mapper_args__ = {"polymorphic_identity": "bar"}
             id = Column(Integer, ForeignKey('foo.id'), primary_key=True)
 
         decl.DeferredReflection.prepare(testing.db)