]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- [feature] *Very limited* support for
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 14 Jul 2012 19:41:31 +0000 (15:41 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 14 Jul 2012 19:41:31 +0000 (15:41 -0400)
    inheriting mappers to be GC'ed when the
    class itself is deferenced.  The mapper
    must not have its own table (i.e.
    single table inh only) without polymorphic
    attributes in place.
    This allows for the use case of
    creating a temporary subclass of a declarative
    mapped class, with no table or mapping
    directives of its own, to be garbage collected
    when dereferenced by a unit test.
    [ticket:2526]

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/_collections.py
lib/sqlalchemy/util/compat.py
test/orm/inheritance/test_abc_polymorphic.py
test/orm/inheritance/test_basic.py
test/perf/orm2010.py

diff --git a/CHANGES b/CHANGES
index e94c55601e86a8572eeeaf11f8dd0c559254a560..c5402cd11d95daab57911285cef59ac75dd8cd01 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -157,7 +157,21 @@ underneath "0.7.xx".
     "Base" that are dereferenced will be
     garbage collected, *if they are not
     referred to by any other mappers/superclass
-    mappers*.  [ticket:2526]
+    mappers*. See the next note for this ticket.
+    [ticket:2526]
+
+  - [feature] *Very limited* support for
+    inheriting mappers to be GC'ed when the
+    class itself is deferenced.  The mapper
+    must not have its own table (i.e.
+    single table inh only) without polymorphic
+    attributes in place.
+    This allows for the use case of
+    creating a temporary subclass of a declarative
+    mapped class, with no table or mapping
+    directives of its own, to be garbage collected
+    when dereferenced by a unit test.
+    [ticket:2526]
 
   - [removed] Deprecated identifiers removed:
 
index 76d6b1165f25c86bff40e367f2c215572902884f..57c8de49869d7ba2880b36469fff4fe247a734cb 100644 (file)
@@ -431,7 +431,7 @@ class Mapper(object):
         being present."""
 
         # a set of all mappers which inherit from this one.
-        self._inheriting_mappers = set()
+        self._inheriting_mappers = util.WeakSet()
 
         if self.inherits:
             if isinstance(self.inherits, type):
@@ -1308,7 +1308,6 @@ class Mapper(object):
             tables = set(sql_util.find_tables(selectable,
                             include_aliases=True))
             mappers = [m for m in mappers if m.local_table in tables]
-
         return mappers
 
     def _selectable_from_mappers(self, mappers, innerjoin):
@@ -1383,7 +1382,7 @@ class Mapper(object):
 
     @_memoized_configured_property
     def _polymorphic_properties(self):
-        return tuple(self._iterate_polymorphic_properties(
+        return list(self._iterate_polymorphic_properties(
             self._with_polymorphic_mappers))
 
     def _iterate_polymorphic_properties(self, mappers=None):
@@ -1401,8 +1400,10 @@ class Mapper(object):
             # from other mappers, as these are sometimes dependent on that
             # mapper's polymorphic selectable (which we don't want rendered)
             for c in util.unique_list(
-                chain(*[list(mapper.iterate_properties) for mapper in [self] +
-                        mappers])
+                chain(*[
+                        list(mapper.iterate_properties) for mapper in
+                        [self] + mappers
+                    ])
             ):
                 if getattr(c, '_is_polymorphic_discriminator', False) and \
                         (self.polymorphic_on is None or
@@ -1587,7 +1588,7 @@ class Mapper(object):
             item = stack.popleft()
             descendants.append(item)
             stack.extend(item._inheriting_mappers)
-        return tuple(descendants)
+        return util.WeakSequence(descendants)
 
     def polymorphic_iterator(self):
         """Iterate through the collection including this mapper and
index cf61eb02a221adf7214e82bdecf8a54cc3a7e4d9..313c6b02c04ce86445e606b764aa1a783c877fab 100644 (file)
@@ -7,7 +7,7 @@
 from compat import callable, cmp, reduce, defaultdict, py25_dict, \
     threading, py3k_warning, jython, pypy, win32, set_types, buffer, pickle, \
     update_wrapper, partial, md5_hex, decode_slice, dottedgetter,\
-    parse_qsl, any, contextmanager, namedtuple, next
+    parse_qsl, any, contextmanager, namedtuple, next, WeakSet
 
 from _collections import NamedTuple, ImmutableContainer, immutabledict, \
     Properties, OrderedProperties, ImmutableProperties, OrderedDict, \
@@ -15,7 +15,7 @@ from _collections import NamedTuple, ImmutableContainer, immutabledict, \
     column_dict, ordered_column_set, populate_column_dict, unique_list, \
     UniqueAppender, PopulateDict, EMPTY_SET, to_list, to_set, \
     to_column_set, update_copy, flatten_iterator, \
-    LRUCache, ScopedRegistry, ThreadLocalRegistry
+    LRUCache, ScopedRegistry, ThreadLocalRegistry, WeakSequence
 
 from langhelpers import iterate_attributes, class_hierarchy, \
     portable_instancemethod, unbound_method_to_callable, \
index d2ed091f4f6e472bf7f13ce147d53b866dfdce87..801a79e9a7aa2add82e8340b460fc11a03f93a60 100644 (file)
@@ -210,7 +210,7 @@ class OrderedDict(dict):
             try:
                 self._list.append(key)
             except AttributeError:
-                # work around Python pickle loads() with 
+                # work around Python pickle loads() with
                 # dict subclass (seems to ignore __setstate__?)
                 self._list = [key]
         dict.__setitem__(self, key, object)
@@ -539,6 +539,20 @@ class IdentitySet(object):
     def __repr__(self):
         return '%s(%r)' % (type(self).__name__, self._members.values())
 
+class WeakSequence(object):
+    def __init__(self, elements):
+        self._storage = weakref.WeakValueDictionary(
+            (idx, element) for idx, element in enumerate(elements)
+        )
+
+    def __iter__(self):
+        return self._storage.itervalues()
+
+    def __getitem__(self, index):
+        try:
+            return self._storage[index]
+        except KeyError:
+            raise IndexError("Index %s out of range" % index)
 
 class OrderedIdentitySet(IdentitySet):
     class _working_set(OrderedSet):
@@ -585,7 +599,7 @@ else:
                 self[key] = value = self.creator(key)
                 return value
 
-# define collections that are capable of storing 
+# define collections that are capable of storing
 # ColumnElement objects as hashable keys/elements.
 column_set = set
 column_dict = dict
@@ -595,12 +609,12 @@ populate_column_dict = PopulateDict
 def unique_list(seq, hashfunc=None):
     seen = {}
     if not hashfunc:
-        return [x for x in seq 
-                if x not in seen 
+        return [x for x in seq
+                if x not in seen
                 and not seen.__setitem__(x, True)]
     else:
-        return [x for x in seq 
-                if hashfunc(x) not in seen 
+        return [x for x in seq
+                if hashfunc(x) not in seen
                 and not seen.__setitem__(hashfunc(x), True)]
 
 class UniqueAppender(object):
@@ -716,15 +730,15 @@ class LRUCache(dict):
 
     def _manage_size(self):
         while len(self) > self.capacity + self.capacity * self.threshold:
-            by_counter = sorted(dict.values(self), 
+            by_counter = sorted(dict.values(self),
                             key=operator.itemgetter(2),
                             reverse=True)
             for item in by_counter[self.capacity:]:
                 try:
                     del self[item[0]]
                 except KeyError:
-                    # if we couldnt find a key, most 
-                    # likely some other thread broke in 
+                    # if we couldnt find a key, most
+                    # likely some other thread broke in
                     # on us. loop around and try again
                     break
 
@@ -785,7 +799,7 @@ class ScopedRegistry(object):
             pass
 
 class ThreadLocalRegistry(ScopedRegistry):
-    """A :class:`.ScopedRegistry` that uses a ``threading.local()`` 
+    """A :class:`.ScopedRegistry` that uses a ``threading.local()``
     variable for storage.
 
     """
index 5dc59b5c52753d8d8bb7fa8f9372e5f609d38c72..215a68e913f60f1ea9dc355c6dbf105aadb214d1 100644 (file)
@@ -166,6 +166,27 @@ except ImportError:
             return 'defaultdict(%s, %s)' % (self.default_factory,
                                             dict.__repr__(self))
 
+try:
+    from weakref import WeakSet
+except:
+    import weakref
+
+    class WeakSet(object):
+        """Implement the small subset of set() which SQLAlchemy needs
+        here. """
+        def __init__(self, values=None):
+            self._storage = weakref.WeakKeyDictionary()
+            if values is not None:
+                self._storage.update((value, None) for value in values)
+
+        def __iter__(self):
+            return iter(self._storage)
+
+        def union(self, other):
+            return WeakSet(set(self).union(other))
+
+        def add(self, other):
+            self._storage[other] = True
 
 # find or create a dict implementation that supports __missing__
 class _probe(dict):
index 8d1b5fec08ef6daa1227cd6bd78c36f85bb576e9..d7589242051018e271e16183511d2a94971cbb8c 100644 (file)
@@ -5,6 +5,7 @@ from sqlalchemy.orm import *
 from test.lib.util import function_named
 from test.lib import fixtures
 from test.orm import _fixtures
+from test.lib.testing import eq_
 from test.lib.schema import Table, Column
 
 class ABCTest(fixtures.MappedTest):
@@ -55,7 +56,8 @@ class ABCTest(fixtures.MappedTest):
 
             #for obj in sess.query(A).all():
             #    print obj
-            assert [
+            eq_(
+                [
                 A(adata='a1'),
                 B(bdata='b1', adata='b1'),
                 B(bdata='b2', adata='b2'),
@@ -63,22 +65,22 @@ class ABCTest(fixtures.MappedTest):
                 C(cdata='c1', bdata='c1', adata='c1'),
                 C(cdata='c2', bdata='c2', adata='c2'),
                 C(cdata='c2', bdata='c2', adata='c2'),
-            ] == sess.query(A).order_by(A.id).all()
+            ], sess.query(A).order_by(A.id).all())
 
-            assert [
+            eq_([
                 B(bdata='b1', adata='b1'),
                 B(bdata='b2', adata='b2'),
                 B(bdata='b3', adata='b3'),
                 C(cdata='c1', bdata='c1', adata='c1'),
                 C(cdata='c2', bdata='c2', adata='c2'),
                 C(cdata='c2', bdata='c2', adata='c2'),
-            ] == sess.query(B).all()
+            ], sess.query(B).all())
 
-            assert [
+            eq_([
                 C(cdata='c1', bdata='c1', adata='c1'),
                 C(cdata='c2', bdata='c2', adata='c2'),
                 C(cdata='c2', bdata='c2', adata='c2'),
-            ] == sess.query(C).all()
+            ], sess.query(C).all())
 
         test_roundtrip = function_named(
             test_roundtrip, 'test_%s' % fetchtype)
index 0e647a66e3249249e2da65e0ad756b40efdaa26e..28aced07b2a54d07af41e19ae638125209b7ff33 100644 (file)
@@ -11,6 +11,7 @@ from test.lib import fixtures
 from test.orm import _fixtures
 from test.lib.schema import Table, Column
 from sqlalchemy.ext.declarative import declarative_base
+from test.lib.util import gc_collect
 
 class O2MTest(fixtures.MappedTest):
     """deals with inheritance and one-to-many relationships"""
@@ -1695,6 +1696,62 @@ class OptimizedLoadTest(fixtures.MappedTest):
             ),
         )
 
+class TransientInheritingGCTest(fixtures.TestBase):
+    def _fixture(self):
+        Base = declarative_base()
+        class A(Base):
+            __tablename__ = 'a'
+            id = Column(Integer, primary_key=True, test_needs_pk=True)
+            data = Column(String(10))
+        self.A = A
+        return Base
+
+    def setUp(self):
+        self.Base = self._fixture()
+
+    def tearDown(self):
+        self.Base.metadata.drop_all(testing.db)
+        #clear_mappers()
+        self.Base = None
+
+    def _do_test(self, go):
+        B = go()
+        self.Base.metadata.create_all(testing.db)
+        sess = Session(testing.db)
+        sess.add(B(data='some b'))
+        sess.commit()
+
+        b1 = sess.query(B).one()
+        assert isinstance(b1, B)
+        sess.close()
+        del sess
+        del b1
+        del B
+
+        gc_collect()
+
+        eq_(
+            len(self.A.__subclasses__()),
+            0)
+
+    def test_single(self):
+        def go():
+            class B(self.A):
+                pass
+            return B
+        self._do_test(go)
+
+    @testing.fails_if(lambda: True,
+                "not supported for joined inh right now.")
+    def test_joined(self):
+        def go():
+            class B(self.A):
+                __tablename__ = 'b'
+                id = Column(Integer, ForeignKey('a.id'),
+                        primary_key=True)
+            return B
+        self._do_test(go)
+
 class NoPKOnSubTableWarningTest(fixtures.TestBase):
 
     def _fixture(self):
index 23bad9c6969182fd01bac64588ed91d7f48d72d6..ed4e7b090866fbc46dff1bf1fb8ffd4e0086914a 100644 (file)
@@ -51,36 +51,36 @@ class Grunt(Employee):
 
     employer_id = Column(Integer, ForeignKey('boss.id'))
 
-    # Configure an 'employer' relationship, where Grunt references 
-    # Boss.  This is a joined-table subclass to subclass relationship, 
+    # Configure an 'employer' relationship, where Grunt references
+    # Boss.  This is a joined-table subclass to subclass relationship,
     # which is a less typical case.
 
     # In 0.7, "Boss.id" is the "id" column of "boss", as would be expected.
     if __version__ >= "0.7":
-        employer = relationship("Boss", backref="employees", 
+        employer = relationship("Boss", backref="employees",
                                     primaryjoin=Boss.id==employer_id)
 
     # Prior to 0.7, "Boss.id" is the "id" column of "employee".
     # Long story.  So we hardwire the relationship against the "id"
     # column of Boss' table.
     elif __version__ >= "0.6":
-        employer = relationship("Boss", backref="employees", 
+        employer = relationship("Boss", backref="employees",
                                 primaryjoin=Boss.__table__.c.id==employer_id)
 
-    # In 0.5, the many-to-one loader wouldn't recognize the above as a 
+    # In 0.5, the many-to-one loader wouldn't recognize the above as a
     # simple "identity map" fetch.  So to give 0.5 a chance to emit
     # the same amount of SQL as 0.6, we hardwire the relationship against
     # "employee.id" to work around the bug.
     else:
-        employer = relationship("Boss", backref="employees", 
-                                primaryjoin=Employee.__table__.c.id==employer_id, 
+        employer = relationship("Boss", backref="employees",
+                                primaryjoin=Employee.__table__.c.id==employer_id,
                                 foreign_keys=employer_id)
 
     __mapper_args__ = {'polymorphic_identity':'grunt'}
 
 if os.path.exists('orm2010.db'):
     os.remove('orm2010.db')
-# use a file based database so that cursor.execute() has some 
+# use a file based database so that cursor.execute() has some
 # palpable overhead.
 engine = create_engine('sqlite:///orm2010.db')
 
@@ -92,7 +92,7 @@ def runit():
     # create 1000 Boss objects.
     bosses = [
         Boss(
-            name="Boss %d" % i, 
+            name="Boss %d" % i,
             golf_average=Decimal(random.randint(40, 150))
         )
         for i in xrange(1000)
@@ -111,9 +111,9 @@ def runit():
     ]
 
     # Assign each Grunt a Boss.  Look them up in the DB
-    # to simulate a little bit of two-way activity with the 
+    # to simulate a little bit of two-way activity with the
     # DB while we populate.  Autoflush occurs on each query.
-    # In 0.7 executemany() is used for all the "boss" and "grunt" 
+    # In 0.7 executemany() is used for all the "boss" and "grunt"
     # tables since priamry key fetching is not needed.
     while grunts:
         boss = sess.query(Boss).\
@@ -131,13 +131,13 @@ def runit():
     # load all the Grunts, print a report with their name, stats,
     # and their bosses' stats.
     for grunt in sess.query(Grunt):
-        # here, the overhead of a many-to-one fetch of 
-        # "grunt.employer" directly from the identity map 
+        # here, the overhead of a many-to-one fetch of
+        # "grunt.employer" directly from the identity map
         # is less than half of that of 0.6.
         report.append((
-                        grunt.name, 
-                        grunt.savings, 
-                        grunt.employer.name, 
+                        grunt.name,
+                        grunt.savings,
+                        grunt.employer.name,
                         grunt.employer.golf_average
                     ))