From: Mike Bayer Date: Mon, 31 May 2010 15:56:08 +0000 (-0400) Subject: change the weakkeydict to be just an LRU cache. Add tests X-Git-Tag: rel_0_6_1~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=cfe9fadc61cfa05c71255fc0e447360199054ffc;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git change the weakkeydict to be just an LRU cache. Add tests for the "many combinations of UPDATE keys" issue. --- diff --git a/CHANGES b/CHANGES index ff379d1973..2648510b70 100644 --- a/CHANGES +++ b/CHANGES @@ -30,7 +30,9 @@ CHANGES the expression objects are cached by the mapper after the first create, and their compiled form is stored persistently in a cache dictionary for the duration of - the related Engine. + the related Engine. The cache is an LRUCache for the + rare case that a mapper receives an extremely + high number of different column patterns as UPDATEs. - sql - expr.in_() now accepts a text() construct as the argument. diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 94e969bcc4..4df51928bd 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -43,6 +43,7 @@ _new_mappers = False _already_compiling = False _none_set = frozenset([None]) + # a list of MapperExtensions that will be installed in all mappers by default global_extensions = [] @@ -97,7 +98,9 @@ class Mapper(object): include_properties=None, exclude_properties=None, passive_updates=True, - eager_defaults=False): + eager_defaults=False, + _compiled_cache_size=100, + ): """Construct a new mapper. Mappers are normally constructed via the :func:`~sqlalchemy.orm.mapper` @@ -140,6 +143,7 @@ class Mapper(object): self._requires_row_aliasing = False self._inherits_equated_pairs = None self._memoized_values = {} + self._compiled_cache_size = _compiled_cache_size if allow_null_pks: util.warn_deprecated('the allow_null_pks option to Mapper() is ' @@ -1264,7 +1268,7 @@ class Mapper(object): @util.memoized_property def _compiled_cache(self): - return weakref.WeakKeyDictionary() + return util.LRUCache(self._compiled_cache_size) @util.memoized_property def _sorted_tables(self): @@ -1342,7 +1346,7 @@ class Mapper(object): cached_connections = util.PopulateDict( lambda conn:conn.execution_options( - compiled_cache=self._compiled_cache.setdefault(conn.engine, {}) + compiled_cache=self._compiled_cache )) # if session has a connection callable, @@ -1740,7 +1744,7 @@ class Mapper(object): tups = [] cached_connections = util.PopulateDict( lambda conn:conn.execution_options( - compiled_cache=self._compiled_cache.setdefault(conn.engine, {}) + compiled_cache=self._compiled_cache )) for state in _sort_states(states): diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index c2c85a4c93..4b04901a1a 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -19,6 +19,7 @@ except ImportError: py3k = getattr(sys, 'py3kwarning', False) or sys.version_info >= (3, 0) jython = sys.platform.startswith('java') +win32 = sys.platform.startswith('win') if py3k: set_types = set @@ -1542,6 +1543,54 @@ class WeakIdentityMapping(weakref.WeakKeyDictionary): def _ref(self, object): return self._keyed_weakref(object, self._cleanup) +import time +if win32 or jython: + time_func = time.clock +else: + time_func = time.time + +class LRUCache(dict): + def __init__(self, capacity=100, threshold=.5): + self.capacity = capacity + self.threshold = threshold + + def __getitem__(self, key): + item = dict.__getitem__(self, key) + item[2] = time_func() + return item[1] + + def values(self): + return [i[1] for i in dict.values(self)] + + def setdefault(self, key, value): + if key in self: + return self[key] + else: + self[key] = value + return value + + def __setitem__(self, key, value): + item = dict.get(self, key) + if item is None: + item = [key, value, time_func()] + dict.__setitem__(self, key, item) + else: + item[1] = value + self._manage_size() + + def _manage_size(self): + while len(self) > self.capacity + self.capacity * self.threshold: + bytime = sorted(dict.values(self), + key=operator.itemgetter(2), + reverse=True) + for item in bytime[self.capacity:]: + try: + del self[item[0]] + except KeyError: + # if we couldnt find a key, most + # likely some other thread broke in + # on us. loop around and try again + break def warn(msg, stacklevel=3): if isinstance(msg, basestring): diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py index 2d64cd8046..2e1810eca6 100644 --- a/test/aaa_profiling/test_memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -1,12 +1,13 @@ from sqlalchemy.test.testing import eq_ -from sqlalchemy.orm import mapper, relationship, create_session, clear_mappers, \ - sessionmaker, class_mapper +from sqlalchemy.orm import mapper, relationship, create_session,\ + clear_mappers, sessionmaker, class_mapper from sqlalchemy.orm.mapper import _mapper_registry from sqlalchemy.orm.session import _sessions from sqlalchemy.util import jython import operator from sqlalchemy.test import testing, engines -from sqlalchemy import MetaData, Integer, String, ForeignKey, PickleType, create_engine +from sqlalchemy import MetaData, Integer, String, ForeignKey, \ + PickleType, create_engine from sqlalchemy.test.schema import Table, Column import sqlalchemy as sa from sqlalchemy.sql import column @@ -48,9 +49,11 @@ def profile_memory(func): else: flatline = True - if not flatline and samples[-1] > samples[0]: # object count is bigger than when it started + # object count is bigger than when it started + if not flatline and samples[-1] > samples[0]: for x in samples[1:-2]: - if x > samples[-1]: # see if a spike bigger than the endpoint exists + # see if a spike bigger than the endpoint exists + if x > samples[-1]: break else: assert False, repr(samples) + " " + repr(flatline) @@ -85,18 +88,21 @@ class MemUsageTest(EnsureZeroed): metadata = MetaData(testing.db) table1 = Table("mytable", metadata, - Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), + Column('col1', Integer, primary_key=True, + test_needs_autoincrement=True), Column('col2', String(30))) table2 = Table("mytable2", metadata, - Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), + Column('col1', Integer, primary_key=True, + test_needs_autoincrement=True), Column('col2', String(30)), Column('col3', Integer, ForeignKey("mytable.col1"))) metadata.create_all() m1 = mapper(A, table1, properties={ - "bs":relationship(B, cascade="all, delete", order_by=table2.c.col1)}, + "bs":relationship(B, cascade="all, delete", + order_by=table2.c.col1)}, order_by=table1.c.col1) m2 = mapper(B, table2) @@ -139,26 +145,36 @@ class MemUsageTest(EnsureZeroed): metadata = MetaData(testing.db) table1 = Table("mytable", metadata, - Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), + Column('col1', Integer, primary_key=True, + test_needs_autoincrement=True), Column('col2', String(30))) table2 = Table("mytable2", metadata, - Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), + Column('col1', Integer, primary_key=True, + test_needs_autoincrement=True), Column('col2', String(30)), Column('col3', Integer, ForeignKey("mytable.col1"))) metadata.create_all() m1 = mapper(A, table1, properties={ - "bs":relationship(B, cascade="all, delete", order_by=table2.c.col1)}, - order_by=table1.c.col1) - m2 = mapper(B, table2) + "bs":relationship(B, cascade="all, delete", + order_by=table2.c.col1)}, + order_by=table1.c.col1, + _compiled_cache_size=10 + ) + m2 = mapper(B, table2, + _compiled_cache_size=10 + ) m3 = mapper(A, table1, non_primary=True) @profile_memory def go(): - engine = engines.testing_engine(options={'logging_name':'FOO', 'pool_logging_name':'BAR'}) + engine = engines.testing_engine( + options={'logging_name':'FOO', + 'pool_logging_name':'BAR'} + ) sess = create_session(bind=engine) a1 = A(col2="a1") @@ -192,15 +208,61 @@ class MemUsageTest(EnsureZeroed): del m1, m2, m3 assert_no_mappers() + def test_many_updates(self): + metadata = MetaData(testing.db) + + wide_table = Table('t', metadata, + Column('id', Integer, primary_key=True), + *[Column('col%d' % i, Integer) for i in range(10)] + ) + + class Wide(object): + pass + + mapper(Wide, wide_table, _compiled_cache_size=10) + + metadata.create_all() + session = create_session() + w1 = Wide() + session.add(w1) + session.flush() + session.close() + del session + counter = [1] + + @profile_memory + def go(): + session = create_session() + w1 = session.query(Wide).first() + x = counter[0] + dec = 10 + while dec > 0: + # trying to count in binary here, + # works enough to trip the test case + if pow(2, dec) < x: + setattr(w1, 'col%d' % dec, counter[0]) + x -= pow(2, dec) + dec -= 1 + session.flush() + session.close() + counter[0] += 1 + + try: + go() + finally: + metadata.drop_all() + def test_mapper_reset(self): metadata = MetaData(testing.db) table1 = Table("mytable", metadata, - Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), + Column('col1', Integer, primary_key=True, + test_needs_autoincrement=True), Column('col2', String(30))) table2 = Table("mytable2", metadata, - Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), + Column('col1', Integer, primary_key=True, + test_needs_autoincrement=True), Column('col2', String(30)), Column('col3', Integer, ForeignKey("mytable.col1"))) @@ -251,7 +313,8 @@ class MemUsageTest(EnsureZeroed): metadata = MetaData(testing.db) table1 = Table("mytable", metadata, - Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), + Column('col1', Integer, primary_key=True, + test_needs_autoincrement=True), Column('col2', String(30)) ) @@ -311,12 +374,14 @@ class MemUsageTest(EnsureZeroed): metadata = MetaData(testing.db) table1 = Table("mytable", metadata, - Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), + Column('col1', Integer, primary_key=True, + test_needs_autoincrement=True), Column('col2', String(30)) ) table2 = Table("mytable2", metadata, - Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), + Column('col1', Integer, primary_key=True, + test_needs_autoincrement=True), Column('col2', String(30)), ) @@ -333,7 +398,8 @@ class MemUsageTest(EnsureZeroed): pass mapper(A, table1, properties={ - 'bs':relationship(B, secondary=table3, backref='as', order_by=table3.c.t1) + 'bs':relationship(B, secondary=table3, + backref='as', order_by=table3.c.t1) }) mapper(B, table2) @@ -381,12 +447,14 @@ class MemUsageTest(EnsureZeroed): metadata = MetaData(testing.db) table1 = Table("table1", metadata, - Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), Column('data', String(30)) ) table2 = Table("table2", metadata, - Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), Column('data', String(30)), Column('t1id', Integer, ForeignKey('table1.id')) ) @@ -420,7 +488,8 @@ class MemUsageTest(EnsureZeroed): metadata = MetaData(testing.db) table1 = Table("mytable", metadata, - Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), + Column('col1', Integer, primary_key=True, + test_needs_autoincrement=True), Column('col2', PickleType(comparator=operator.eq)) ) diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 035e4f2682..f9888ef0c5 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -112,6 +112,52 @@ class ColumnCollectionTest(TestBase): assert (cc1==cc2).compare(c1 == c2) assert not (cc1==cc3).compare(c2 == c3) + + +class LRUTest(TestBase): + + def test_lru(self): + class item(object): + def __init__(self, id): + self.id = id + + def __str__(self): + return "item id %d" % self.id + + l = util.LRUCache(10, threshold=.2) + + for id in range(1,20): + l[id] = item(id) + + # first couple of items should be gone + assert 1 not in l + assert 2 not in l + + # next batch over the threshold of 10 should be present + for id_ in range(11,20): + assert id_ in l + + l[12] + l[15] + l[23] = item(23) + l[24] = item(24) + l[25] = item(25) + l[26] = item(26) + l[27] = item(27) + + assert 11 not in l + assert 13 not in l + + for id_ in (25, 24, 23, 14, 12, 19, 18, 17, 16, 15): + assert id_ in l + + i1 = l[25] + i2 = item(25) + l[25] = i2 + assert 25 in l + assert l[25] is i2 + + class ImmutableSubclass(str): pass