]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- symbols now depickle properly
authorJason Kirtland <jek@discorporate.us>
Wed, 19 Mar 2008 18:02:47 +0000 (18:02 +0000)
committerJason Kirtland <jek@discorporate.us>
Wed, 19 Mar 2008 18:02:47 +0000 (18:02 +0000)
- fixed some symbol __new__ abuse

lib/sqlalchemy/util.py
test/base/utils.py

index 9adb3983db73bac0816c22c035358b9c879a609d..af77d792eec1b6c610a8d53838683e6077b24998 100644 (file)
@@ -981,6 +981,17 @@ class ScopedRegistry(object):
     def _get_key(self):
         return self.scopefunc()
 
+class _symbol(object):
+    def __init__(self, name):
+        """Construct a new named symbol."""
+        assert isinstance(name, str)
+        self.name = name
+    def __reduce__(self):
+        return symbol, (self.name,)
+    def __repr__(self):
+        return "<symbol '%s>" % self.name
+_symbol.__name__ = 'symbol'
+
 class symbol(object):
     """A constant symbol.
 
@@ -991,32 +1002,23 @@ class symbol(object):
 
     A slight refinement of the MAGICCOOKIE=object() pattern.  The primary
     advantage of symbol() is its repr().  They are also singletons.
-    """
 
+    Repeated calls of symbol('name') will all return the same instance.
+
+    """
     symbols = {}
     _lock = threading.Lock()
 
     def __new__(cls, name):
+        cls._lock.acquire()
         try:
-            symbol._lock.acquire()
             sym = cls.symbols.get(name)
             if sym is None:
-                cls.symbols[name] = sym = object.__new__(cls, name)
+                cls.symbols[name] = sym = _symbol(name)
             return sym
         finally:
             symbol._lock.release()
 
-    def __init__(self, name):
-        """Construct a new named symbol.
-
-        Repeated calls of symbol('name') will all return the same instance.
-        """
-
-        assert isinstance(name, str)
-        self.name = name
-    def __repr__(self):
-        return "<symbol '%s>" % self.name
-
 def warn(msg):
     if isinstance(msg, basestring):
         warnings.warn(msg, exceptions.SAWarning, stacklevel=3)
index 6ab141d6cbd74dc6be64d362531dd4ca5b95d53f..fc72cf8e12ff7c618453725c195f2d0f7d44a231 100644 (file)
@@ -389,5 +389,36 @@ class ArgInspectionTest(TestBase):
         test(f3)
         test(f4)
 
+class SymbolTest(TestBase):
+    def test_basic(self):
+        sym1 = util.symbol('foo')
+        assert sym1.name == 'foo'
+        sym2 = util.symbol('foo')
+
+        assert sym1 is sym2
+        assert sym1 == sym2
+
+        sym3 = util.symbol('bar')
+        assert sym1 is not sym3
+        assert sym1 != sym3
+
+    def test_pickle(self):
+        sym1 = util.symbol('foo')
+        sym2 = util.symbol('foo')
+
+        assert sym1 is sym2
+
+        # default
+        s = util.pickle.dumps(sym1)
+        sym3 = util.pickle.loads(s)
+
+        for protocol in 0, 1, 2:
+            print protocol
+            serial = util.pickle.dumps(sym1)
+            rt = util.pickle.loads(serial)
+            assert rt is sym1
+            assert rt is sym2
+
+
 if __name__ == "__main__":
     testenv.main()