]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Change the WeakKeyDictionary to be just an LRU cache; add tests
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 31 May 2010 15:56:08 +0000 (11:56 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 31 May 2010 15:56:08 +0000 (11:56 -0400)
for the "many combinations of UPDATE keys" issue.

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/util.py
test/aaa_profiling/test_memusage.py
test/base/test_utils.py

diff --git a/CHANGES b/CHANGES
index ff379d197320e76033f501314a420fc4ac0e4967..2648510b70c9bc3f98b7dd9789446e3d4d354a9f 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -30,7 +30,9 @@ CHANGES
     the expression objects are cached by the mapper after 
     the first create, and their compiled form is stored
     persistently in a cache dictionary for the duration of
-    the related Engine.
+    the related Engine.  The cache is an LRUCache for the
+    rare case that a mapper receives an extremely
+    high number of different column patterns as UPDATEs.
     
 - sql
   - expr.in_() now accepts a text() construct as the argument.
index 94e969bcc48f6830f4a9db07eea6f273f3a44c0d..4df51928bd61188147ab101713fca04f72911152 100644 (file)
@@ -43,6 +43,7 @@ _new_mappers = False
 _already_compiling = False
 _none_set = frozenset([None])
 
+
 # a list of MapperExtensions that will be installed in all mappers by default
 global_extensions = []
 
@@ -97,7 +98,9 @@ class Mapper(object):
                  include_properties=None,
                  exclude_properties=None,
                  passive_updates=True,
-                 eager_defaults=False):
+                 eager_defaults=False,
+                 _compiled_cache_size=100,
+                 ):
         """Construct a new mapper.
 
         Mappers are normally constructed via the :func:`~sqlalchemy.orm.mapper`
@@ -140,6 +143,7 @@ class Mapper(object):
         self._requires_row_aliasing = False
         self._inherits_equated_pairs = None
         self._memoized_values = {}
+        self._compiled_cache_size = _compiled_cache_size
         
         if allow_null_pks:
             util.warn_deprecated('the allow_null_pks option to Mapper() is '
@@ -1264,7 +1268,7 @@ class Mapper(object):
 
     @util.memoized_property
     def _compiled_cache(self):
-        return weakref.WeakKeyDictionary()
+        return util.LRUCache(self._compiled_cache_size)
 
     @util.memoized_property
     def _sorted_tables(self):
@@ -1342,7 +1346,7 @@ class Mapper(object):
 
         cached_connections = util.PopulateDict(
             lambda conn:conn.execution_options(
-            compiled_cache=self._compiled_cache.setdefault(conn.engine, {})
+            compiled_cache=self._compiled_cache
         ))
 
         # if session has a connection callable,
@@ -1740,7 +1744,7 @@ class Mapper(object):
         tups = []
         cached_connections = util.PopulateDict(
             lambda conn:conn.execution_options(
-            compiled_cache=self._compiled_cache.setdefault(conn.engine, {})
+            compiled_cache=self._compiled_cache
         ))
         
         for state in _sort_states(states):
index c2c85a4c93c26f4993a1a29b76c1721a18a332d2..4b04901a1a56088331d2b2cb7f3a3081ff09e127 100644 (file)
@@ -19,6 +19,7 @@ except ImportError:
 
 py3k = getattr(sys, 'py3kwarning', False) or sys.version_info >= (3, 0)
 jython = sys.platform.startswith('java')
+win32 = sys.platform.startswith('win')
 
 if py3k:
     set_types = set
@@ -1542,6 +1543,54 @@ class WeakIdentityMapping(weakref.WeakKeyDictionary):
     def _ref(self, object):
         return self._keyed_weakref(object, self._cleanup)
 
import time

# time.clock has better resolution than time.time on Windows and
# Jython (2010-era interpreters); elsewhere time.time is the better
# wall-clock source.
if sys.platform.startswith('win') or sys.platform.startswith('java'):
    time_func = time.clock
else:
    time_func = time.time


class LRUCache(dict):
    """A dict-based LRU cache with lazy, batched eviction.

    Each entry is stored internally as a mutable record
    ``[key, value, last_access_time]``.  Reads (and writes) refresh
    the timestamp; once the dict grows past
    ``capacity + capacity * threshold`` entries, everything beyond the
    ``capacity`` most recently used entries is evicted in one pass.
    The threshold margin keeps the O(n log n) eviction sort amortized
    across many inserts instead of running on every one.
    """

    def __init__(self, capacity=100, threshold=.5):
        self.capacity = capacity
        self.threshold = threshold

    def __getitem__(self, key):
        item = dict.__getitem__(self, key)
        # refresh the access time so this entry survives eviction
        item[2] = time_func()
        return item[1]

    def get(self, key, default=None):
        # plain dict.get() would leak the raw [key, value, time]
        # record; unwrap it (and refresh the timestamp) the same way
        # __getitem__ does.
        item = dict.get(self, key)
        if item is None:
            return default
        item[2] = time_func()
        return item[1]

    def values(self):
        # unwrap the internal records
        return [rec[1] for rec in dict.values(self)]

    def items(self):
        # unwrap the internal records, as values() does
        return [(rec[0], rec[1]) for rec in dict.values(self)]

    def setdefault(self, key, value):
        if key in self:
            # existing entry; __getitem__ refreshes its timestamp
            return self[key]
        else:
            self[key] = value
            return value

    def __setitem__(self, key, value):
        item = dict.get(self, key)
        if item is None:
            item = [key, value, time_func()]
            dict.__setitem__(self, key, item)
        else:
            # replace the value in the existing record, refreshing its
            # timestamp so a hot, frequently-rewritten key isn't evicted
            item[1] = value
            item[2] = time_func()
        self._manage_size()

    def _manage_size(self):
        # evict only once we're past capacity by the threshold margin,
        # so the sort below runs once per batch of inserts rather than
        # on every insert
        while len(self) > self.capacity + self.capacity * self.threshold:
            bytime = sorted(dict.values(self),
                            key=operator.itemgetter(2),
                            reverse=True)
            for item in bytime[self.capacity:]:
                try:
                    del self[item[0]]
                except KeyError:
                    # if we couldn't find a key, most likely some
                    # other thread broke in on us. loop around and
                    # try again
                    break
 
 def warn(msg, stacklevel=3):
     if isinstance(msg, basestring):
index 2d64cd8046b4867a3dcf0043ca679d51eb1ad9cd..2e1810eca67c4838b65bf9817e7ea89e5638642e 100644 (file)
@@ -1,12 +1,13 @@
 from sqlalchemy.test.testing import eq_
-from sqlalchemy.orm import mapper, relationship, create_session, clear_mappers, \
-                            sessionmaker, class_mapper
+from sqlalchemy.orm import mapper, relationship, create_session,\
+                            clear_mappers, sessionmaker, class_mapper
 from sqlalchemy.orm.mapper import _mapper_registry
 from sqlalchemy.orm.session import _sessions
 from sqlalchemy.util import jython
 import operator
 from sqlalchemy.test import testing, engines
-from sqlalchemy import MetaData, Integer, String, ForeignKey, PickleType, create_engine
+from sqlalchemy import MetaData, Integer, String, ForeignKey, \
+                            PickleType, create_engine
 from sqlalchemy.test.schema import Table, Column
 import sqlalchemy as sa
 from sqlalchemy.sql import column
@@ -48,9 +49,11 @@ def profile_memory(func):
         else:
             flatline = True
 
-        if not flatline and samples[-1] > samples[0]:  # object count is bigger than when it started
+        # object count is bigger than when it started
+        if not flatline and samples[-1] > samples[0]:  
             for x in samples[1:-2]:
-                if x > samples[-1]:     # see if a spike bigger than the endpoint exists
+                # see if a spike bigger than the endpoint exists
+                if x > samples[-1]:
                     break
             else:
                 assert False, repr(samples) + " " + repr(flatline)
@@ -85,18 +88,21 @@ class MemUsageTest(EnsureZeroed):
         metadata = MetaData(testing.db)
 
         table1 = Table("mytable", metadata,
-            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('col1', Integer, primary_key=True,
+                                    test_needs_autoincrement=True),
             Column('col2', String(30)))
 
         table2 = Table("mytable2", metadata,
-            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('col1', Integer, primary_key=True,
+                                    test_needs_autoincrement=True),
             Column('col2', String(30)),
             Column('col3', Integer, ForeignKey("mytable.col1")))
 
         metadata.create_all()
 
         m1 = mapper(A, table1, properties={
-            "bs":relationship(B, cascade="all, delete", order_by=table2.c.col1)},
+            "bs":relationship(B, cascade="all, delete",
+                                    order_by=table2.c.col1)},
             order_by=table1.c.col1)
         m2 = mapper(B, table2)
 
@@ -139,26 +145,36 @@ class MemUsageTest(EnsureZeroed):
         metadata = MetaData(testing.db)
 
         table1 = Table("mytable", metadata,
-            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('col1', Integer, primary_key=True,
+                            test_needs_autoincrement=True),
             Column('col2', String(30)))
 
         table2 = Table("mytable2", metadata,
-            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('col1', Integer, primary_key=True,
+                            test_needs_autoincrement=True),
             Column('col2', String(30)),
             Column('col3', Integer, ForeignKey("mytable.col1")))
 
         metadata.create_all()
 
         m1 = mapper(A, table1, properties={
-            "bs":relationship(B, cascade="all, delete", order_by=table2.c.col1)},
-            order_by=table1.c.col1)
-        m2 = mapper(B, table2)
+            "bs":relationship(B, cascade="all, delete",
+                                    order_by=table2.c.col1)},
+            order_by=table1.c.col1,
+            _compiled_cache_size=10
+            )
+        m2 = mapper(B, table2,
+            _compiled_cache_size=10
+        )
 
         m3 = mapper(A, table1, non_primary=True)
 
         @profile_memory
         def go():
-            engine = engines.testing_engine(options={'logging_name':'FOO', 'pool_logging_name':'BAR'})
+            engine = engines.testing_engine(
+                                options={'logging_name':'FOO',
+                                        'pool_logging_name':'BAR'}
+                                    )
             sess = create_session(bind=engine)
             
             a1 = A(col2="a1")
@@ -192,15 +208,61 @@ class MemUsageTest(EnsureZeroed):
         del m1, m2, m3
         assert_no_mappers()
 
+    def test_many_updates(self):
+        # Regression test for the "many combinations of UPDATE keys"
+        # issue: each flush below UPDATEs a different subset of columns,
+        # producing a distinct compiled-statement cache key per pattern.
+        # With _compiled_cache_size=10 the mapper's LRUCache must evict
+        # old entries so memory stays flat (checked by @profile_memory).
+        metadata = MetaData(testing.db)
+        
+        # a table wide enough (10 value columns) to generate many
+        # distinct column-subset UPDATE statements
+        wide_table = Table('t', metadata,
+            Column('id', Integer, primary_key=True),
+            *[Column('col%d' % i, Integer) for i in range(10)]
+        )
+        
+        class Wide(object):
+            pass
+        
+        mapper(Wide, wide_table, _compiled_cache_size=10)
+        
+        metadata.create_all()
+        session = create_session()
+        w1 = Wide()
+        session.add(w1)
+        session.flush()
+        session.close()
+        del session
+        # mutable one-element list shared with the closure below
+        # (no "nonlocal" available here); bumped once per go() call
+        counter = [1]
+        
+        @profile_memory
+        def go():
+            session = create_session()
+            w1 = session.query(Wide).first()
+            x = counter[0]
+            dec = 10
+            # set the columns corresponding to the high bits of the
+            # counter, so successive calls touch different subsets
+            while dec > 0:
+                # trying to count in binary here, 
+                # works enough to trip the test case
+                if pow(2, dec) < x:
+                    setattr(w1, 'col%d' % dec, counter[0])
+                    x -= pow(2, dec)
+                dec -= 1
+            session.flush()
+            session.close()
+            counter[0] += 1
+            
+        try:
+            go()
+        finally:
+            metadata.drop_all()
+        
     def test_mapper_reset(self):
         metadata = MetaData(testing.db)
 
         table1 = Table("mytable", metadata,
-            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('col1', Integer, primary_key=True,
+                                test_needs_autoincrement=True),
             Column('col2', String(30)))
 
         table2 = Table("mytable2", metadata,
-            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('col1', Integer, primary_key=True,
+                                test_needs_autoincrement=True),
             Column('col2', String(30)),
             Column('col3', Integer, ForeignKey("mytable.col1")))
 
@@ -251,7 +313,8 @@ class MemUsageTest(EnsureZeroed):
         metadata = MetaData(testing.db)
 
         table1 = Table("mytable", metadata,
-            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('col1', Integer, primary_key=True,
+                                    test_needs_autoincrement=True),
             Column('col2', String(30))
             )
 
@@ -311,12 +374,14 @@ class MemUsageTest(EnsureZeroed):
         metadata = MetaData(testing.db)
 
         table1 = Table("mytable", metadata,
-            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('col1', Integer, primary_key=True,
+                                    test_needs_autoincrement=True),
             Column('col2', String(30))
             )
 
         table2 = Table("mytable2", metadata,
-            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('col1', Integer, primary_key=True,
+                                    test_needs_autoincrement=True),
             Column('col2', String(30)),
             )
 
@@ -333,7 +398,8 @@ class MemUsageTest(EnsureZeroed):
                 pass
 
             mapper(A, table1, properties={
-                'bs':relationship(B, secondary=table3, backref='as', order_by=table3.c.t1)
+                'bs':relationship(B, secondary=table3, 
+                                    backref='as', order_by=table3.c.t1)
             })
             mapper(B, table2)
 
@@ -381,12 +447,14 @@ class MemUsageTest(EnsureZeroed):
         metadata = MetaData(testing.db)
 
         table1 = Table("table1", metadata,
-            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('id', Integer, primary_key=True,
+                                    test_needs_autoincrement=True),
             Column('data', String(30))
             )
 
         table2 = Table("table2", metadata,
-            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('id', Integer, primary_key=True,
+                                    test_needs_autoincrement=True),
             Column('data', String(30)),
             Column('t1id', Integer, ForeignKey('table1.id'))
             )
@@ -420,7 +488,8 @@ class MemUsageTest(EnsureZeroed):
         metadata = MetaData(testing.db)
 
         table1 = Table("mytable", metadata,
-            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('col1', Integer, primary_key=True,
+                                test_needs_autoincrement=True),
             Column('col2', PickleType(comparator=operator.eq))
             )
         
index 035e4f2682e148248f075b576ac01a8338478dc9..f9888ef0c5c89529d8a68a3df6357749676975ec 100644 (file)
@@ -112,6 +112,52 @@ class ColumnCollectionTest(TestBase):
         assert (cc1==cc2).compare(c1 == c2)
         assert not (cc1==cc3).compare(c2 == c3)
 
+
+
+class LRUTest(TestBase):
+
+    def test_lru(self):                
+        # a distinct value type whose identity matters, so we can
+        # verify that replacing a cache entry stores the new object
+        class item(object):
+            def __init__(self, id):
+                self.id = id
+
+            def __str__(self):
+                return "item id %d" % self.id
+
+        # capacity 10, threshold .2: trimming kicks in once size
+        # exceeds capacity + capacity * threshold, i.e. 12 entries
+        l = util.LRUCache(10, threshold=.2)
+
+        for id in range(1,20):
+            l[id] = item(id)
+
+        # first couple of items should be gone
+        assert 1 not in l
+        assert 2 not in l
+
+        # next batch over the threshold of 10 should be present
+        for id_ in range(11,20):
+            assert id_ in l
+
+        # __getitem__ refreshes the access time of 12 and 15,
+        # protecting them from the eviction triggered below
+        l[12]
+        l[15]
+        l[23] = item(23)
+        l[24] = item(24)
+        l[25] = item(25)
+        l[26] = item(26)
+        l[27] = item(27)
+
+        # least-recently-used entries that were NOT touched above
+        # should have been evicted
+        assert 11 not in l
+        assert 13 not in l
+
+        for id_ in (25, 24, 23, 14, 12, 19, 18, 17, 16, 15):
+            assert id_ in l
+
+        # overwriting an existing key replaces the stored value
+        i1 = l[25]
+        i2 = item(25)
+        l[25] = i2
+        assert 25 in l
+        assert l[25] is i2
+        
+
 class ImmutableSubclass(str):
     pass