]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Change shelve to require a bytes-oriented dict as
authorMartin v. Löwis <martin@v.loewis.de>
Sat, 11 Aug 2007 06:57:14 +0000 (06:57 +0000)
committerMartin v. Löwis <martin@v.loewis.de>
Sat, 11 Aug 2007 06:57:14 +0000 (06:57 +0000)
the underlying storage, and yet provide string keys.

Lib/shelve.py
Lib/test/test_shelve.py

index 5759d4ec63ce15246467fb0e19f48c7610862011..586d2535d840a861669da47ac812fa5620f1e3ba 100644 (file)
@@ -71,25 +71,28 @@ class Shelf(UserDict.DictMixin):
     See the module's __doc__ string for an overview of the interface.
     """
 
-    def __init__(self, dict, protocol=None, writeback=False):
+    def __init__(self, dict, protocol=None, writeback=False,
+                 keyencoding="utf-8"):
         self.dict = dict
         if protocol is None:
             protocol = 0
         self._protocol = protocol
         self.writeback = writeback
         self.cache = {}
+        self.keyencoding = "utf-8"
 
     def keys(self):
-        return self.dict.keys()
+        for k in self.dict.keys():
+            yield k.decode(self.keyencoding)
 
     def __len__(self):
         return len(self.dict)
 
     def __contains__(self, key):
-        return key in self.dict
+        return key.encode(self.keyencoding) in self.dict
 
     def get(self, key, default=None):
-        if key in self.dict:
+        if key.encode(self.keyencoding) in self.dict:
             return self[key]
         return default
 
@@ -97,7 +100,7 @@ class Shelf(UserDict.DictMixin):
         try:
             value = self.cache[key]
         except KeyError:
-            f = BytesIO(self.dict[key])
+            f = BytesIO(self.dict[key.encode(self.keyencoding)])
             value = Unpickler(f).load()
             if self.writeback:
                 self.cache[key] = value
@@ -109,10 +112,10 @@ class Shelf(UserDict.DictMixin):
         f = BytesIO()
         p = Pickler(f, self._protocol)
         p.dump(value)
-        self.dict[key] = f.getvalue()
+        self.dict[key.encode(self.keyencoding)] = f.getvalue()
 
     def __delitem__(self, key):
-        del self.dict[key]
+        del self.dict[key.encode(self.keyencoding)]
         try:
             del self.cache[key]
         except KeyError:
@@ -156,33 +159,34 @@ class BsdDbShelf(Shelf):
     See the module's __doc__ string for an overview of the interface.
     """
 
-    def __init__(self, dict, protocol=None, writeback=False):
-        Shelf.__init__(self, dict, protocol, writeback)
+    def __init__(self, dict, protocol=None, writeback=False,
+                 keyencoding="utf-8"):
+        Shelf.__init__(self, dict, protocol, writeback, keyencoding)
 
     def set_location(self, key):
         (key, value) = self.dict.set_location(key)
         f = BytesIO(value)
-        return (key, Unpickler(f).load())
+        return (key.decode(self.keyencoding), Unpickler(f).load())
 
     def next(self):
         (key, value) = next(self.dict)
         f = BytesIO(value)
-        return (key, Unpickler(f).load())
+        return (key.decode(self.keyencoding), Unpickler(f).load())
 
     def previous(self):
         (key, value) = self.dict.previous()
         f = BytesIO(value)
-        return (key, Unpickler(f).load())
+        return (key.decode(self.keyencoding), Unpickler(f).load())
 
     def first(self):
         (key, value) = self.dict.first()
         f = BytesIO(value)
-        return (key, Unpickler(f).load())
+        return (key.decode(self.keyencoding), Unpickler(f).load())
 
     def last(self):
         (key, value) = self.dict.last()
         f = BytesIO(value)
-        return (key, Unpickler(f).load())
+        return (key.decode(self.keyencoding), Unpickler(f).load())
 
 
 class DbfilenameShelf(Shelf):
index 802462c41121f1b59f98a1f739e0cdaf4ee133d2..543afb1fabcc3653c1dd73b8cf0a9add1d5b8e27 100644 (file)
@@ -2,6 +2,36 @@ import unittest
 import shelve
 import glob
 from test import test_support
+from UserDict import DictMixin
+
+def L1(s):
+    return s.decode("latin-1")
+
+class byteskeydict(DictMixin):
+    "Mapping that supports bytes keys"
+
+    def __init__(self):
+        self.d = {}
+
+    def __getitem__(self, key):
+        return self.d[L1(key)]
+
+    def __setitem__(self, key, value):
+        self.d[L1(key)] = value
+
+    def __delitem__(self, key):
+        del self.d[L1(key)]
+
+    def iterkeys(self):
+        for k in self.d.keys():
+            yield k.decode("latin-1")
+
+    def keys(self):
+        return list(self.iterkeys())
+
+    def copy(self):
+        return byteskeydict(self.d)
+
 
 class TestCase(unittest.TestCase):
 
@@ -36,12 +66,12 @@ class TestCase(unittest.TestCase):
             s.close()
 
     def test_in_memory_shelf(self):
-        d1 = {}
+        d1 = byteskeydict()
         s = shelve.Shelf(d1, protocol=0)
         s['key1'] = (1,2,3,4)
         self.assertEqual(s['key1'], (1,2,3,4))
         s.close()
-        d2 = {}
+        d2 = byteskeydict()
         s = shelve.Shelf(d2, protocol=1)
         s['key1'] = (1,2,3,4)
         self.assertEqual(s['key1'], (1,2,3,4))
@@ -51,7 +81,7 @@ class TestCase(unittest.TestCase):
         self.assertNotEqual(d1, d2)
 
     def test_mutable_entry(self):
-        d1 = {}
+        d1 = byteskeydict()
         s = shelve.Shelf(d1, protocol=2, writeback=False)
         s['key1'] = [1,2,3,4]
         self.assertEqual(s['key1'], [1,2,3,4])
@@ -59,7 +89,7 @@ class TestCase(unittest.TestCase):
         self.assertEqual(s['key1'], [1,2,3,4])
         s.close()
 
-        d2 = {}
+        d2 = byteskeydict()
         s = shelve.Shelf(d2, protocol=2, writeback=True)
         s['key1'] = [1,2,3,4]
         self.assertEqual(s['key1'], [1,2,3,4])
@@ -84,7 +114,7 @@ class TestShelveBase(mapping_tests.BasicTestMappingProtocol):
         return {"key1":"value1", "key2":2, "key3":(1,2,3)}
     def _empty_mapping(self):
         if self._in_mem:
-            x= shelve.Shelf({}, **self._args)
+            x= shelve.Shelf(byteskeydict(), **self._args)
         else:
             self.counter+=1
             x= shelve.open(self.fn+str(self.counter), **self._args)