]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Make subclasses of int, long, complex, float, and unicode perform type
authorBrett Cannon <bcannon@gmail.com>
Tue, 26 Apr 2005 03:45:26 +0000 (03:45 +0000)
committerBrett Cannon <bcannon@gmail.com>
Tue, 26 Apr 2005 03:45:26 +0000 (03:45 +0000)
conversion using the proper magic slot (e.g., __int__()).  Also move conversion
code out of PyNumber_*() functions in the C API into the nb_* function.

Applied patch #1109424.  Thanks Walter Doewald.

Lib/test/test_builtin.py
Lib/test/test_complex.py
Lib/test/test_str.py
Lib/test/test_unicode.py
Misc/NEWS
Objects/abstract.c
Objects/floatobject.c
Objects/intobject.c
Objects/longobject.c
Objects/object.c

index 4e8ffe5337ffa3e62c03060c0b8b6ba8f2a6b59e..103d1a3ac626ab5e9ff0f97d222a11be60dfa4a8 100644 (file)
@@ -545,6 +545,37 @@ class BuiltinTest(unittest.TestCase):
             self.assertEqual(float(unicode("  3.14  ")), 3.14)
             self.assertEqual(float(unicode("  \u0663.\u0661\u0664  ",'raw-unicode-escape')), 3.14)
 
+    def test_floatconversion(self):
+        # Make sure that calls to __float__() work properly
+        class Foo0:
+            def __float__(self):
+                return 42.
+
+        class Foo1(object):
+            def __float__(self):
+                return 42.
+
+        class Foo2(float):
+            def __float__(self):
+                return 42.
+
+        class Foo3(float):
+            def __new__(cls, value=0.):
+                return float.__new__(cls, 2*value)
+
+            def __float__(self):
+                return self
+
+        class Foo4(float):
+            def __float__(self):
+                return 42
+
+        self.assertAlmostEqual(float(Foo0()), 42.)
+        self.assertAlmostEqual(float(Foo1()), 42.)
+        self.assertAlmostEqual(float(Foo2()), 42.)
+        self.assertAlmostEqual(float(Foo3(21)), 42.)
+        self.assertRaises(TypeError, float, Foo4(42))
+
     def test_getattr(self):
         import sys
         self.assert_(getattr(sys, 'stdout') is sys.stdout)
@@ -650,6 +681,39 @@ class BuiltinTest(unittest.TestCase):
 
         self.assertEqual(int('0123', 0), 83)
 
+    def test_intconversion(self):
+        # Test __int__()
+        class Foo0:
+            def __int__(self):
+                return 42
+
+        class Foo1(object):
+            def __int__(self):
+                return 42
+
+        class Foo2(int):
+            def __int__(self):
+                return 42
+
+        class Foo3(int):
+            def __int__(self):
+                return self
+
+        class Foo4(int):
+            def __int__(self):
+                return 42L
+
+        class Foo5(int):
+            def __int__(self):
+                return 42.
+
+        self.assertEqual(int(Foo0()), 42)
+        self.assertEqual(int(Foo1()), 42)
+        self.assertEqual(int(Foo2()), 42)
+        self.assertEqual(int(Foo3()), 0)
+        self.assertEqual(int(Foo4()), 42L)
+        self.assertRaises(TypeError, int, Foo5())
+
     def test_intern(self):
         self.assertRaises(TypeError, intern)
         s = "never interned before"
@@ -810,6 +874,39 @@ class BuiltinTest(unittest.TestCase):
         self.assertRaises(ValueError, long, '53', 40)
         self.assertRaises(TypeError, long, 1, 12)
 
+    def test_longconversion(self):
+        # Test __long__()
+        class Foo0:
+            def __long__(self):
+                return 42L
+
+        class Foo1(object):
+            def __long__(self):
+                return 42L
+
+        class Foo2(long):
+            def __long__(self):
+                return 42L
+
+        class Foo3(long):
+            def __long__(self):
+                return self
+
+        class Foo4(long):
+            def __long__(self):
+                return 42
+
+        class Foo5(long):
+            def __long__(self):
+                return 42.
+
+        self.assertEqual(long(Foo0()), 42L)
+        self.assertEqual(long(Foo1()), 42L)
+        self.assertEqual(long(Foo2()), 42L)
+        self.assertEqual(long(Foo3()), 0)
+        self.assertEqual(long(Foo4()), 42)
+        self.assertRaises(TypeError, long, Foo5())
+
     def test_map(self):
         self.assertEqual(
             map(None, 'hello world'),
index 15f4b654794e9172f7f772ca1585c5a268f4487a..70e91c1c9e3d4bfa169df98f44812d995ae28d4e 100644 (file)
@@ -273,6 +273,28 @@ class ComplexTest(unittest.TestCase):
         self.assertAlmostEqual(complex(real=float2(17.), imag=float2(23.)), 17+23j)
         self.assertRaises(TypeError, complex, float2(None))
 
+        class complex0(complex):
+            """Test usage of __complex__() when inheriting from 'complex'"""
+            def __complex__(self):
+                return 42j
+
+        class complex1(complex):
+            """Test usage of __complex__() with a __new__() method"""
+            def __new__(self, value=0j):
+                return complex.__new__(self, 2*value)
+            def __complex__(self):
+                return self
+
+        class complex2(complex):
+            """Make sure that __complex__() calls fail if anything other than a
+            complex is returned"""
+            def __complex__(self):
+                return None
+
+        self.assertAlmostEqual(complex(complex0(1j)), 42j)
+        self.assertAlmostEqual(complex(complex1(1j)), 2j)
+        self.assertRaises(TypeError, complex, complex2(1j))
+
     def test_hash(self):
         for x in xrange(-30, 30):
             self.assertEqual(hash(x), hash(complex(x, 0)))
index 82632f10ebb86c86f4c20b21be5de1c90fc45a4d..45942a66ef24ce5cf9c38f651e269de2844ae929 100644 (file)
@@ -19,6 +19,69 @@ class StrTest(
         string_tests.MixinStrUnicodeUserStringTest.test_formatting(self)
         self.assertRaises(OverflowError, '%c'.__mod__, 0x1234)
 
+    def test_conversion(self):
+        # Make sure __str__() behaves properly
+        class Foo0:
+            def __unicode__(self):
+                return u"foo"
+
+        class Foo1:
+            def __str__(self):
+                return "foo"
+
+        class Foo2(object):
+            def __str__(self):
+                return "foo"
+
+        class Foo3(object):
+            def __str__(self):
+                return u"foo"
+
+        class Foo4(unicode):
+            def __str__(self):
+                return u"foo"
+
+        class Foo5(str):
+            def __str__(self):
+                return u"foo"
+
+        class Foo6(str):
+            def __str__(self):
+                return "foos"
+
+            def __unicode__(self):
+                return u"foou"
+
+        class Foo7(unicode):
+            def __str__(self):
+                return "foos"
+            def __unicode__(self):
+                return u"foou"
+
+        class Foo8(str):
+            def __new__(cls, content=""):
+                return str.__new__(cls, 2*content)
+            def __str__(self):
+                return self
+
+        class Foo9(str):
+            def __str__(self):
+                return "string"
+            def __unicode__(self):
+                return "not unicode"
+
+        self.assert_(str(Foo0()).startswith("<")) # this is different from __unicode__
+        self.assertEqual(str(Foo1()), "foo")
+        self.assertEqual(str(Foo2()), "foo")
+        self.assertEqual(str(Foo3()), "foo")
+        self.assertEqual(str(Foo4("bar")), "foo")
+        self.assertEqual(str(Foo5("bar")), "foo")
+        self.assertEqual(str(Foo6("bar")), "foos")
+        self.assertEqual(str(Foo7("bar")), "foos")
+        self.assertEqual(str(Foo8("foo")), "foofoo")
+        self.assertEqual(str(Foo9("foo")), "string")
+        self.assertEqual(unicode(Foo9("foo")), u"not unicode")
+
 def test_main():
     test_support.run_unittest(StrTest)
 
index 69244f0d6f3d8c6102b0c5a8cf4528debddb9fc3..80242d519921d66513e2cb393905ab99c05e1636 100644 (file)
@@ -389,7 +389,6 @@ class UnicodeTest(
         self.assertEqual('%i%s %*.*s' % (10, 3, 5, 3, u'abc',), u'103   abc')
         self.assertEqual('%c' % u'a', u'a')
 
-
     def test_constructor(self):
         # unicode(obj) tests (this maps to PyObject_Unicode() at C level)
 
@@ -725,6 +724,69 @@ class UnicodeTest(
         y = x.encode("raw-unicode-escape").decode("raw-unicode-escape")
         self.assertEqual(x, y)
 
+    def test_conversion(self):
+        # Make sure __unicode__() works properly
+        class Foo0:
+            def __str__(self):
+                return "foo"
+
+        class Foo1:
+            def __unicode__(self):
+                return u"foo"
+
+        class Foo2(object):
+            def __unicode__(self):
+                return u"foo"
+
+        class Foo3(object):
+            def __unicode__(self):
+                return "foo"
+
+        class Foo4(str):
+            def __unicode__(self):
+                return "foo"
+
+        class Foo5(unicode):
+            def __unicode__(self):
+                return "foo"
+
+        class Foo6(str):
+            def __str__(self):
+                return "foos"
+
+            def __unicode__(self):
+                return u"foou"
+
+        class Foo7(unicode):
+            def __str__(self):
+                return "foos"
+            def __unicode__(self):
+                return u"foou"
+
+        class Foo8(unicode):
+            def __new__(cls, content=""):
+                return unicode.__new__(cls, 2*content)
+            def __unicode__(self):
+                return self
+
+        class Foo9(unicode):
+            def __str__(self):
+                return "string"
+            def __unicode__(self):
+                return "not unicode"
+
+        self.assertEqual(unicode(Foo0()), u"foo")
+        self.assertEqual(unicode(Foo1()), u"foo")
+        self.assertEqual(unicode(Foo2()), u"foo")
+        self.assertEqual(unicode(Foo3()), u"foo")
+        self.assertEqual(unicode(Foo4("bar")), u"foo")
+        self.assertEqual(unicode(Foo5("bar")), u"foo")
+        self.assertEqual(unicode(Foo6("bar")), u"foou")
+        self.assertEqual(unicode(Foo7("bar")), u"foou")
+        self.assertEqual(unicode(Foo8("foo")), u"foofoo")
+        self.assertEqual(str(Foo9("foo")), "string")
+        self.assertEqual(unicode(Foo9("foo")), u"not unicode")
+
 def test_main():
     test_support.run_unittest(UnicodeTest)
 
index 3c6b6c14841f7d77d1c3dcca9b84c8fd1f72cd86..8d70fce031d8e43972c85c1704d3696f285f0f0b 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -12,6 +12,14 @@ What's New in Python 2.5 alpha 1?
 Core and builtins
 -----------------
 
+- patch #1109424: int, long, float, complex, and unicode now check for the
+  proper magic slot for type conversions when subclassed.  Previously the
+  magic slot was ignored during conversion.  Semantics now match the way
+  subclasses of str always behaved.  int/long/float, conversion of an instance
+  to the base class has been moved the prroper nb_* magic slot and out of
+  PyNumber_*().
+  Thanks Walter Dörwald.
+
 - Descriptors defined in C with a PyGetSetDef structure, where the setter is
   NULL, now raise an AttributeError when attempting to set or delete the
   attribute.  Previously a TypeError was raised, but this was inconsistent
index 875c8804f726c1e8be210c858f00dac702fd0d81..d28006a97d30911dd0afcbe965b8b4c6c7334a5f 100644 (file)
@@ -951,7 +951,19 @@ PyNumber_Int(PyObject *o)
                Py_INCREF(o);
                return o;
        }
-       if (PyInt_Check(o)) {
+       m = o->ob_type->tp_as_number;
+       if (m && m->nb_int) { /* This should include subclasses of int */
+               PyObject *res = m->nb_int(o);
+               if (res && (!PyInt_Check(res) && !PyLong_Check(res))) {
+                       PyErr_Format(PyExc_TypeError,
+                                    "__int__ returned non-int (type %.200s)",
+                                    res->ob_type->tp_name);
+                       Py_DECREF(res);
+                       return NULL;
+               }
+               return res;
+       }
+       if (PyInt_Check(o)) { /* A int subclass without nb_int */
                PyIntObject *io = (PyIntObject*)o;
                return PyInt_FromLong(io->ob_ival);
        }
@@ -964,18 +976,6 @@ PyNumber_Int(PyObject *o)
                                         PyUnicode_GET_SIZE(o),
                                         10);
 #endif
-       m = o->ob_type->tp_as_number;
-       if (m && m->nb_int) {
-               PyObject *res = m->nb_int(o);
-               if (res && (!PyInt_Check(res) && !PyLong_Check(res))) {
-                       PyErr_Format(PyExc_TypeError,
-                                    "__int__ returned non-int (type %.200s)",
-                                    res->ob_type->tp_name);
-                       Py_DECREF(res);
-                       return NULL;
-               }
-               return res;
-       }
        if (!PyObject_AsCharBuffer(o, &buffer, &buffer_len))
                return int_from_string((char*)buffer, buffer_len);
 
@@ -1010,11 +1010,19 @@ PyNumber_Long(PyObject *o)
 
        if (o == NULL)
                return null_error();
-       if (PyLong_CheckExact(o)) {
-               Py_INCREF(o);
-               return o;
+       m = o->ob_type->tp_as_number;
+       if (m && m->nb_long) { /* This should include subclasses of long */
+               PyObject *res = m->nb_long(o);
+               if (res && (!PyInt_Check(res) && !PyLong_Check(res))) {
+                       PyErr_Format(PyExc_TypeError,
+                                    "__long__ returned non-long (type %.200s)",
+                                    res->ob_type->tp_name);
+                       Py_DECREF(res);
+                       return NULL;
+               }
+               return res;
        }
-       if (PyLong_Check(o))
+       if (PyLong_Check(o)) /* A long subclass without nb_long */
                return _PyLong_Copy((PyLongObject *)o);
        if (PyString_Check(o))
                /* need to do extra error checking that PyLong_FromString()
@@ -1030,18 +1038,6 @@ PyNumber_Long(PyObject *o)
                                          PyUnicode_GET_SIZE(o),
                                          10);
 #endif
-       m = o->ob_type->tp_as_number;
-       if (m && m->nb_long) {
-               PyObject *res = m->nb_long(o);
-               if (res && (!PyInt_Check(res) && !PyLong_Check(res))) {
-                       PyErr_Format(PyExc_TypeError,
-                                    "__long__ returned non-long (type %.200s)",
-                                    res->ob_type->tp_name);
-                       Py_DECREF(res);
-                       return NULL;
-               }
-               return res;
-       }
        if (!PyObject_AsCharBuffer(o, &buffer, &buffer_len))
                return long_from_string(buffer, buffer_len);
 
@@ -1055,28 +1051,22 @@ PyNumber_Float(PyObject *o)
 
        if (o == NULL)
                return null_error();
-       if (PyFloat_CheckExact(o)) {
-               Py_INCREF(o);
-               return o;
+       m = o->ob_type->tp_as_number;
+       if (m && m->nb_float) { /* This should include subclasses of float */
+               PyObject *res = m->nb_float(o);
+               if (res && !PyFloat_Check(res)) {
+                       PyErr_Format(PyExc_TypeError,
+                         "__float__ returned non-float (type %.200s)",
+                         res->ob_type->tp_name);
+                       Py_DECREF(res);
+                       return NULL;
+               }
+               return res;
        }
-       if (PyFloat_Check(o)) {
+       if (PyFloat_Check(o)) { /* A float subclass with nb_float == NULL */
                PyFloatObject *po = (PyFloatObject *)o;
                return PyFloat_FromDouble(po->ob_fval);
        }
-       if (!PyString_Check(o)) {
-               m = o->ob_type->tp_as_number;
-               if (m && m->nb_float) {
-                       PyObject *res = m->nb_float(o);
-                       if (res && !PyFloat_Check(res)) {
-                               PyErr_Format(PyExc_TypeError,
-                                 "__float__ returned non-float (type %.200s)",
-                                 res->ob_type->tp_name);
-                               Py_DECREF(res);
-                               return NULL;
-                       }
-                       return res;
-               }
-       }
        return PyFloat_FromString(o, NULL);
 }
 
index 539c4a9f4ce21dedaaf6a1382affb966e9aeb763..55f43cb69b37600ee704ec2183386c4c7ac90d68 100644 (file)
@@ -926,7 +926,10 @@ float_int(PyObject *v)
 static PyObject *
 float_float(PyObject *v)
 {
-       Py_INCREF(v);
+       if (PyFloat_CheckExact(v))
+               Py_INCREF(v);
+       else
+               v = PyFloat_FromDouble(((PyFloatObject *)v)->ob_fval);
        return v;
 }
 
index 763ed53d4b9cd308df5e5389af530da9c0d05ebc..0ead74b0fb5e2e00f42701d9b06c2275fcab0434 100644 (file)
@@ -826,7 +826,10 @@ int_coerce(PyObject **pv, PyObject **pw)
 static PyObject *
 int_int(PyIntObject *v)
 {
-       Py_INCREF(v);
+       if (PyInt_CheckExact(v))
+               Py_INCREF(v);
+       else
+               v = (PyIntObject *)PyInt_FromLong(v->ob_ival);
        return (PyObject *)v;
 }
 
index 11a7024e4543eda49e84b8f7bfe7bc7062d530b4..e4fc553a80c6c258efa4195249770573074d75fc 100644 (file)
@@ -2861,7 +2861,10 @@ long_coerce(PyObject **pv, PyObject **pw)
 static PyObject *
 long_long(PyObject *v)
 {
-       Py_INCREF(v);
+       if (PyLong_CheckExact(v))
+               Py_INCREF(v);
+       else
+               v = _PyLong_Copy((PyLongObject *)v);
        return v;
 }
 
index d86d74f6d7659c14fcdd7eea0506156fc6fc6e17..975c967cca7d220c1463bf59a73840407a6c4e15 100644 (file)
@@ -373,6 +373,8 @@ PyObject *
 PyObject_Unicode(PyObject *v)
 {
        PyObject *res;
+       PyObject *func;
+       static PyObject *unicodestr;
 
        if (v == NULL)
                res = PyString_FromString("<NULL>");
@@ -380,35 +382,32 @@ PyObject_Unicode(PyObject *v)
                Py_INCREF(v);
                return v;
        }
-       if (PyUnicode_Check(v)) {
-               /* For a Unicode subtype that's not a Unicode object,
-                  return a true Unicode object with the same data. */
-               return PyUnicode_FromUnicode(PyUnicode_AS_UNICODE(v),
-                                            PyUnicode_GET_SIZE(v));
+       /* XXX As soon as we have a tp_unicode slot, we should
+          check this before trying the __unicode__
+          method. */
+       if (unicodestr == NULL) {
+               unicodestr= PyString_InternFromString("__unicode__");
+               if (unicodestr == NULL)
+                       return NULL;
+       }
+       func = PyObject_GetAttr(v, unicodestr);
+       if (func != NULL) {
+               res = PyEval_CallObject(func, (PyObject *)NULL);
+               Py_DECREF(func);
        }
-       if (PyString_Check(v)) {
-               Py_INCREF(v);
-               res = v;
-       }
        else {
-               PyObject *func;
-               static PyObject *unicodestr;
-               /* XXX As soon as we have a tp_unicode slot, we should
-                      check this before trying the __unicode__
-                      method. */
-               if (unicodestr == NULL) {
-                       unicodestr= PyString_InternFromString(
-                                                      "__unicode__");
-                       if (unicodestr == NULL)
-                               return NULL;
+               PyErr_Clear();
+               if (PyUnicode_Check(v)) {
+                       /* For a Unicode subtype that's didn't overwrite __unicode__,
+                          return a true Unicode object with the same data. */
+                       return PyUnicode_FromUnicode(PyUnicode_AS_UNICODE(v),
+                                                    PyUnicode_GET_SIZE(v));
                }
-               func = PyObject_GetAttr(v, unicodestr);
-               if (func != NULL) {
-                       res = PyEval_CallObject(func, (PyObject *)NULL);
-                       Py_DECREF(func);
+               if (PyString_CheckExact(v)) {
+                       Py_INCREF(v);
+                       res = v;
                }
                else {
-                       PyErr_Clear();
                        if (v->ob_type->tp_str != NULL)
                                res = (*v->ob_type->tp_str)(v);
                        else
@@ -424,7 +423,7 @@ PyObject_Unicode(PyObject *v)
                if (str)
                        res = str;
                else
-                       return NULL;
+                       return NULL;
        }
        return res;
 }