From 800efc75256283770d5c28ddd99f26f341733698 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 14 Jul 2012 15:41:31 -0400 Subject: [PATCH] - [feature] *Very limited* support for inheriting mappers to be GC'ed when the class itself is deferenced. The mapper must not have its own table (i.e. single table inh only) without polymorphic attributes in place. This allows for the use case of creating a temporary subclass of a declarative mapped class, with no table or mapping directives of its own, to be garbage collected when dereferenced by a unit test. [ticket:2526] --- CHANGES | 16 +++++- lib/sqlalchemy/orm/mapper.py | 13 ++--- lib/sqlalchemy/util/__init__.py | 4 +- lib/sqlalchemy/util/_collections.py | 34 ++++++++---- lib/sqlalchemy/util/compat.py | 21 ++++++++ test/orm/inheritance/test_abc_polymorphic.py | 14 ++--- test/orm/inheritance/test_basic.py | 57 ++++++++++++++++++++ test/perf/orm2010.py | 32 +++++------ 8 files changed, 150 insertions(+), 41 deletions(-) diff --git a/CHANGES b/CHANGES index e94c55601e..c5402cd11d 100644 --- a/CHANGES +++ b/CHANGES @@ -157,7 +157,21 @@ underneath "0.7.xx". "Base" that are dereferenced will be garbage collected, *if they are not referred to by any other mappers/superclass - mappers*. [ticket:2526] + mappers*. See the next note for this ticket. + [ticket:2526] + + - [feature] *Very limited* support for + inheriting mappers to be GC'ed when the + class itself is deferenced. The mapper + must not have its own table (i.e. + single table inh only) without polymorphic + attributes in place. + This allows for the use case of + creating a temporary subclass of a declarative + mapped class, with no table or mapping + directives of its own, to be garbage collected + when dereferenced by a unit test. + [ticket:2526] - [removed] Deprecated identifiers removed: diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 76d6b1165f..57c8de4986 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -431,7 +431,7 @@ class Mapper(object): being present.""" # a set of all mappers which inherit from this one. - self._inheriting_mappers = set() + self._inheriting_mappers = util.WeakSet() if self.inherits: if isinstance(self.inherits, type): @@ -1308,7 +1308,6 @@ class Mapper(object): tables = set(sql_util.find_tables(selectable, include_aliases=True)) mappers = [m for m in mappers if m.local_table in tables] - return mappers def _selectable_from_mappers(self, mappers, innerjoin): @@ -1383,7 +1382,7 @@ class Mapper(object): @_memoized_configured_property def _polymorphic_properties(self): - return tuple(self._iterate_polymorphic_properties( + return list(self._iterate_polymorphic_properties( self._with_polymorphic_mappers)) def _iterate_polymorphic_properties(self, mappers=None): @@ -1401,8 +1400,10 @@ class Mapper(object): # from other mappers, as these are sometimes dependent on that # mapper's polymorphic selectable (which we don't want rendered) for c in util.unique_list( - chain(*[list(mapper.iterate_properties) for mapper in [self] + - mappers]) + chain(*[ + list(mapper.iterate_properties) for mapper in + [self] + mappers + ]) ): if getattr(c, '_is_polymorphic_discriminator', False) and \ (self.polymorphic_on is None or @@ -1587,7 +1588,7 @@ class Mapper(object): item = stack.popleft() descendants.append(item) stack.extend(item._inheriting_mappers) - return tuple(descendants) + return util.WeakSequence(descendants) def polymorphic_iterator(self): """Iterate through the collection including this mapper and diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index cf61eb02a2..313c6b02c0 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -7,7 +7,7 @@ from compat import callable, cmp, reduce, defaultdict, py25_dict, \ threading, py3k_warning, jython, pypy, win32, set_types, buffer, pickle, \ update_wrapper, partial, md5_hex, decode_slice, dottedgetter,\ - parse_qsl, any, contextmanager, namedtuple, next + parse_qsl, any, contextmanager, namedtuple, next, WeakSet from _collections import NamedTuple, ImmutableContainer, immutabledict, \ Properties, OrderedProperties, ImmutableProperties, OrderedDict, \ @@ -15,7 +15,7 @@ from _collections import NamedTuple, ImmutableContainer, immutabledict, \ column_dict, ordered_column_set, populate_column_dict, unique_list, \ UniqueAppender, PopulateDict, EMPTY_SET, to_list, to_set, \ to_column_set, update_copy, flatten_iterator, \ - LRUCache, ScopedRegistry, ThreadLocalRegistry + LRUCache, ScopedRegistry, ThreadLocalRegistry, WeakSequence from langhelpers import iterate_attributes, class_hierarchy, \ portable_instancemethod, unbound_method_to_callable, \ diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index d2ed091f4f..801a79e9a7 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -210,7 +210,7 @@ class OrderedDict(dict): try: self._list.append(key) except AttributeError: - # work around Python pickle loads() with + # work around Python pickle loads() with # dict subclass (seems to ignore __setstate__?) self._list = [key] dict.__setitem__(self, key, object) @@ -539,6 +539,20 @@ class IdentitySet(object): def __repr__(self): return '%s(%r)' % (type(self).__name__, self._members.values()) +class WeakSequence(object): + def __init__(self, elements): + self._storage = weakref.WeakValueDictionary( + (idx, element) for idx, element in enumerate(elements) + ) + + def __iter__(self): + return self._storage.itervalues() + + def __getitem__(self, index): + try: + return self._storage[index] + except KeyError: + raise IndexError("Index %s out of range" % index) class OrderedIdentitySet(IdentitySet): class _working_set(OrderedSet): @@ -585,7 +599,7 @@ else: self[key] = value = self.creator(key) return value -# define collections that are capable of storing +# define collections that are capable of storing # ColumnElement objects as hashable keys/elements. column_set = set column_dict = dict @@ -595,12 +609,12 @@ populate_column_dict = PopulateDict def unique_list(seq, hashfunc=None): seen = {} if not hashfunc: - return [x for x in seq - if x not in seen + return [x for x in seq + if x not in seen and not seen.__setitem__(x, True)] else: - return [x for x in seq - if hashfunc(x) not in seen + return [x for x in seq + if hashfunc(x) not in seen and not seen.__setitem__(hashfunc(x), True)] class UniqueAppender(object): @@ -716,15 +730,15 @@ class LRUCache(dict): def _manage_size(self): while len(self) > self.capacity + self.capacity * self.threshold: - by_counter = sorted(dict.values(self), + by_counter = sorted(dict.values(self), key=operator.itemgetter(2), reverse=True) for item in by_counter[self.capacity:]: try: del self[item[0]] except KeyError: - # if we couldnt find a key, most - # likely some other thread broke in + # if we couldnt find a key, most + # likely some other thread broke in # on us. loop around and try again break @@ -785,7 +799,7 @@ class ScopedRegistry(object): pass class ThreadLocalRegistry(ScopedRegistry): - """A :class:`.ScopedRegistry` that uses a ``threading.local()`` + """A :class:`.ScopedRegistry` that uses a ``threading.local()`` variable for storage. """ diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 5dc59b5c52..215a68e913 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -166,6 +166,27 @@ except ImportError: return 'defaultdict(%s, %s)' % (self.default_factory, dict.__repr__(self)) +try: + from weakref import WeakSet +except: + import weakref + + class WeakSet(object): + """Implement the small subset of set() which SQLAlchemy needs + here. """ + def __init__(self, values=None): + self._storage = weakref.WeakKeyDictionary() + if values is not None: + self._storage.update((value, None) for value in values) + + def __iter__(self): + return iter(self._storage) + + def union(self, other): + return WeakSet(set(self).union(other)) + + def add(self, other): + self._storage[other] = True # find or create a dict implementation that supports __missing__ class _probe(dict): diff --git a/test/orm/inheritance/test_abc_polymorphic.py b/test/orm/inheritance/test_abc_polymorphic.py index 8d1b5fec08..d758924205 100644 --- a/test/orm/inheritance/test_abc_polymorphic.py +++ b/test/orm/inheritance/test_abc_polymorphic.py @@ -5,6 +5,7 @@ from sqlalchemy.orm import * from test.lib.util import function_named from test.lib import fixtures from test.orm import _fixtures +from test.lib.testing import eq_ from test.lib.schema import Table, Column class ABCTest(fixtures.MappedTest): @@ -55,7 +56,8 @@ class ABCTest(fixtures.MappedTest): #for obj in sess.query(A).all(): # print obj - assert [ + eq_( + [ A(adata='a1'), B(bdata='b1', adata='b1'), B(bdata='b2', adata='b2'), @@ -63,22 +65,22 @@ class ABCTest(fixtures.MappedTest): C(cdata='c1', bdata='c1', adata='c1'), C(cdata='c2', bdata='c2', adata='c2'), C(cdata='c2', bdata='c2', adata='c2'), - ] == sess.query(A).order_by(A.id).all() + ], sess.query(A).order_by(A.id).all()) - assert [ + eq_([ B(bdata='b1', adata='b1'), B(bdata='b2', adata='b2'), B(bdata='b3', adata='b3'), C(cdata='c1', bdata='c1', adata='c1'), C(cdata='c2', bdata='c2', adata='c2'), C(cdata='c2', bdata='c2', adata='c2'), - ] == sess.query(B).all() + ], sess.query(B).all()) - assert [ + eq_([ C(cdata='c1', bdata='c1', adata='c1'), C(cdata='c2', bdata='c2', adata='c2'), C(cdata='c2', bdata='c2', adata='c2'), - ] == sess.query(C).all() + ], sess.query(C).all()) test_roundtrip = function_named( test_roundtrip, 'test_%s' % fetchtype) diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index 0e647a66e3..28aced07b2 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -11,6 +11,7 @@ from test.lib import fixtures from test.orm import _fixtures from test.lib.schema import Table, Column from sqlalchemy.ext.declarative import declarative_base +from test.lib.util import gc_collect class O2MTest(fixtures.MappedTest): """deals with inheritance and one-to-many relationships""" @@ -1695,6 +1696,62 @@ class OptimizedLoadTest(fixtures.MappedTest): ), ) +class TransientInheritingGCTest(fixtures.TestBase): + def _fixture(self): + Base = declarative_base() + class A(Base): + __tablename__ = 'a' + id = Column(Integer, primary_key=True, test_needs_pk=True) + data = Column(String(10)) + self.A = A + return Base + + def setUp(self): + self.Base = self._fixture() + + def tearDown(self): + self.Base.metadata.drop_all(testing.db) + #clear_mappers() + self.Base = None + + def _do_test(self, go): + B = go() + self.Base.metadata.create_all(testing.db) + sess = Session(testing.db) + sess.add(B(data='some b')) + sess.commit() + + b1 = sess.query(B).one() + assert isinstance(b1, B) + sess.close() + del sess + del b1 + del B + + gc_collect() + + eq_( + len(self.A.__subclasses__()), + 0) + + def test_single(self): + def go(): + class B(self.A): + pass + return B + self._do_test(go) + + @testing.fails_if(lambda: True, + "not supported for joined inh right now.") + def test_joined(self): + def go(): + class B(self.A): + __tablename__ = 'b' + id = Column(Integer, ForeignKey('a.id'), + primary_key=True) + return B + self._do_test(go) + class NoPKOnSubTableWarningTest(fixtures.TestBase): def _fixture(self): diff --git a/test/perf/orm2010.py b/test/perf/orm2010.py index 23bad9c696..ed4e7b0908 100644 --- a/test/perf/orm2010.py +++ b/test/perf/orm2010.py @@ -51,36 +51,36 @@ class Grunt(Employee): employer_id = Column(Integer, ForeignKey('boss.id')) - # Configure an 'employer' relationship, where Grunt references - # Boss. This is a joined-table subclass to subclass relationship, + # Configure an 'employer' relationship, where Grunt references + # Boss. This is a joined-table subclass to subclass relationship, # which is a less typical case. # In 0.7, "Boss.id" is the "id" column of "boss", as would be expected. if __version__ >= "0.7": - employer = relationship("Boss", backref="employees", + employer = relationship("Boss", backref="employees", primaryjoin=Boss.id==employer_id) # Prior to 0.7, "Boss.id" is the "id" column of "employee". # Long story. So we hardwire the relationship against the "id" # column of Boss' table. elif __version__ >= "0.6": - employer = relationship("Boss", backref="employees", + employer = relationship("Boss", backref="employees", primaryjoin=Boss.__table__.c.id==employer_id) - # In 0.5, the many-to-one loader wouldn't recognize the above as a + # In 0.5, the many-to-one loader wouldn't recognize the above as a # simple "identity map" fetch. So to give 0.5 a chance to emit # the same amount of SQL as 0.6, we hardwire the relationship against # "employee.id" to work around the bug. else: - employer = relationship("Boss", backref="employees", - primaryjoin=Employee.__table__.c.id==employer_id, + employer = relationship("Boss", backref="employees", + primaryjoin=Employee.__table__.c.id==employer_id, foreign_keys=employer_id) __mapper_args__ = {'polymorphic_identity':'grunt'} if os.path.exists('orm2010.db'): os.remove('orm2010.db') -# use a file based database so that cursor.execute() has some +# use a file based database so that cursor.execute() has some # palpable overhead. engine = create_engine('sqlite:///orm2010.db') @@ -92,7 +92,7 @@ def runit(): # create 1000 Boss objects. bosses = [ Boss( - name="Boss %d" % i, + name="Boss %d" % i, golf_average=Decimal(random.randint(40, 150)) ) for i in xrange(1000) @@ -111,9 +111,9 @@ def runit(): ] # Assign each Grunt a Boss. Look them up in the DB - # to simulate a little bit of two-way activity with the + # to simulate a little bit of two-way activity with the # DB while we populate. Autoflush occurs on each query. - # In 0.7 executemany() is used for all the "boss" and "grunt" + # In 0.7 executemany() is used for all the "boss" and "grunt" # tables since priamry key fetching is not needed. while grunts: boss = sess.query(Boss).\ @@ -131,13 +131,13 @@ def runit(): # load all the Grunts, print a report with their name, stats, # and their bosses' stats. for grunt in sess.query(Grunt): - # here, the overhead of a many-to-one fetch of - # "grunt.employer" directly from the identity map + # here, the overhead of a many-to-one fetch of + # "grunt.employer" directly from the identity map # is less than half of that of 0.6. report.append(( - grunt.name, - grunt.savings, - grunt.employer.name, + grunt.name, + grunt.savings, + grunt.employer.name, grunt.employer.golf_average )) -- 2.47.3