]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-89039: Call subclass constructors in datetime.*.replace (GH-114780)
authorEugene Toder <eltoder@users.noreply.github.com>
Mon, 12 Feb 2024 12:44:56 +0000 (07:44 -0500)
committerGitHub <noreply@github.com>
Mon, 12 Feb 2024 12:44:56 +0000 (14:44 +0200)
When replace() method is called on a subclass of datetime, date or time,
properly call derived constructor. Previously, only the base class's
constructor was called.

Also, make sure to pass non-zero fold values when creating subclasses in
various methods. Previously, fold was silently ignored.

Lib/test/datetimetester.py
Misc/NEWS.d/next/Library/2023-12-18-20-10-50.gh-issue-89039.gqFdtU.rst [new file with mode: 0644]
Modules/_datetimemodule.c

index 980a8e6c1b183691550b8622b1fa2ef32b5245d0..31fc383e29707a717aa0833175d10d41e426a566 100644 (file)
@@ -1723,11 +1723,24 @@ class TestDate(HarmlessMixedComparison, unittest.TestCase):
 
     def test_subclass_replace(self):
         class DateSubclass(self.theclass):
-            pass
+            def __new__(cls, *args, **kwargs):
+                result = self.theclass.__new__(cls, *args, **kwargs)
+                result.extra = 7
+                return result
 
         dt = DateSubclass(2012, 1, 1)
-        self.assertIs(type(dt.replace(year=2013)), DateSubclass)
-        self.assertIs(type(copy.replace(dt, year=2013)), DateSubclass)
+
+        test_cases = [
+            ('self.replace', dt.replace(year=2013)),
+            ('copy.replace', copy.replace(dt, year=2013)),
+        ]
+
+        for name, res in test_cases:
+            with self.subTest(name):
+                self.assertIs(type(res), DateSubclass)
+                self.assertEqual(res.year, 2013)
+                self.assertEqual(res.month, 1)
+                self.assertEqual(res.extra, 7)
 
     def test_subclass_date(self):
 
@@ -3025,6 +3038,26 @@ class TestDateTime(TestDate):
                 self.assertIsInstance(dt, DateTimeSubclass)
                 self.assertEqual(dt.extra, 7)
 
+    def test_subclass_replace_fold(self):
+        class DateTimeSubclass(self.theclass):
+            pass
+
+        dt = DateTimeSubclass(2012, 1, 1)
+        dt2 = DateTimeSubclass(2012, 1, 1, fold=1)
+
+        test_cases = [
+            ('self.replace', dt.replace(year=2013), 0),
+            ('self.replace', dt2.replace(year=2013), 1),
+            ('copy.replace', copy.replace(dt, year=2013), 0),
+            ('copy.replace', copy.replace(dt2, year=2013), 1),
+        ]
+
+        for name, res, fold in test_cases:
+            with self.subTest(name, fold=fold):
+                self.assertIs(type(res), DateTimeSubclass)
+                self.assertEqual(res.year, 2013)
+                self.assertEqual(res.fold, fold)
+
     def test_fromisoformat_datetime(self):
         # Test that isoformat() is reversible
         base_dates = [
@@ -3705,11 +3738,28 @@ class TestTime(HarmlessMixedComparison, unittest.TestCase):
 
     def test_subclass_replace(self):
         class TimeSubclass(self.theclass):
-            pass
+            def __new__(cls, *args, **kwargs):
+                result = self.theclass.__new__(cls, *args, **kwargs)
+                result.extra = 7
+                return result
 
         ctime = TimeSubclass(12, 30)
-        self.assertIs(type(ctime.replace(hour=10)), TimeSubclass)
-        self.assertIs(type(copy.replace(ctime, hour=10)), TimeSubclass)
+        ctime2 = TimeSubclass(12, 30, fold=1)
+
+        test_cases = [
+            ('self.replace', ctime.replace(hour=10), 0),
+            ('self.replace', ctime2.replace(hour=10), 1),
+            ('copy.replace', copy.replace(ctime, hour=10), 0),
+            ('copy.replace', copy.replace(ctime2, hour=10), 1),
+        ]
+
+        for name, res, fold in test_cases:
+            with self.subTest(name, fold=fold):
+                self.assertIs(type(res), TimeSubclass)
+                self.assertEqual(res.hour, 10)
+                self.assertEqual(res.minute, 30)
+                self.assertEqual(res.extra, 7)
+                self.assertEqual(res.fold, fold)
 
     def test_subclass_time(self):
 
diff --git a/Misc/NEWS.d/next/Library/2023-12-18-20-10-50.gh-issue-89039.gqFdtU.rst b/Misc/NEWS.d/next/Library/2023-12-18-20-10-50.gh-issue-89039.gqFdtU.rst
new file mode 100644 (file)
index 0000000..d1998d7
--- /dev/null
@@ -0,0 +1,6 @@
+When replace() method is called on a subclass of datetime, date or time,
+properly call derived constructor. Previously, only the base class's
+constructor was called.
+
+Also, make sure to pass non-zero fold values when creating subclasses in
+various methods. Previously, fold was silently ignored.
index b984ea61b82f0fbcfe7d068acd77da1c07da7fa9..014ccdd3f6effe4a386f70eb7f286acb473ffe07 100644 (file)
@@ -1045,6 +1045,40 @@ new_datetime_ex(int year, int month, int day, int hour, int minute,
     new_datetime_ex2(y, m, d, hh, mm, ss, us, tzinfo, fold, \
                     &PyDateTime_DateTimeType)
 
+static PyObject *
+call_subclass_fold(PyObject *cls, int fold, const char *format, ...)
+{
+    PyObject *kwargs = NULL, *res = NULL;
+    va_list va;
+
+    va_start(va, format);
+    PyObject *args = Py_VaBuildValue(format, va);
+    va_end(va);
+    if (args == NULL) {
+        return NULL;
+    }
+    if (fold) {
+        kwargs = PyDict_New();
+        if (kwargs == NULL) {
+            goto Done;
+        }
+        PyObject *obj = PyLong_FromLong(fold);
+        if (obj == NULL) {
+            goto Done;
+        }
+        int err = PyDict_SetItemString(kwargs, "fold", obj);
+        Py_DECREF(obj);
+        if (err < 0) {
+            goto Done;
+        }
+    }
+    res = PyObject_Call(cls, args, kwargs);
+Done:
+    Py_DECREF(args);
+    Py_XDECREF(kwargs);
+    return res;
+}
+
 static PyObject *
 new_datetime_subclass_fold_ex(int year, int month, int day, int hour, int minute,
                               int second, int usecond, PyObject *tzinfo,
@@ -1054,17 +1088,11 @@ new_datetime_subclass_fold_ex(int year, int month, int day, int hour, int minute
         // Use the fast path constructor
         dt = new_datetime(year, month, day, hour, minute, second, usecond,
                           tzinfo, fold);
-    } else {
+    }
+    else {
         // Subclass
-        dt = PyObject_CallFunction(cls, "iiiiiiiO",
-                                   year,
-                                   month,
-                                   day,
-                                   hour,
-                                   minute,
-                                   second,
-                                   usecond,
-                                   tzinfo);
+        dt = call_subclass_fold(cls, fold, "iiiiiiiO", year, month, day,
+                                hour, minute, second, usecond, tzinfo);
     }
 
     return dt;
@@ -1120,6 +1148,24 @@ new_time_ex(int hour, int minute, int second, int usecond,
 #define new_time(hh, mm, ss, us, tzinfo, fold)                       \
     new_time_ex2(hh, mm, ss, us, tzinfo, fold, &PyDateTime_TimeType)
 
+static PyObject *
+new_time_subclass_fold_ex(int hour, int minute, int second, int usecond,
+                          PyObject *tzinfo, int fold, PyObject *cls)
+{
+    PyObject *t;
+    if ((PyTypeObject*)cls == &PyDateTime_TimeType) {
+        // Use the fast path constructor
+        t = new_time(hour, minute, second, usecond, tzinfo, fold);
+    }
+    else {
+        // Subclass
+        t = call_subclass_fold(cls, fold, "iiiiO", hour, minute, second,
+                               usecond, tzinfo);
+    }
+
+    return t;
+}
+
 /* Create a timedelta instance.  Normalize the members iff normalize is
  * true.  Passing false is a speed optimization, if you know for sure
  * that seconds and microseconds are already in their proper ranges.  In any
@@ -3480,7 +3526,7 @@ datetime_date_replace_impl(PyDateTime_Date *self, int year, int month,
                            int day)
 /*[clinic end generated code: output=2a9430d1e6318aeb input=0d1f02685b3e90f6]*/
 {
-    return new_date_ex(year, month, day, Py_TYPE(self));
+    return new_date_subclass_ex(year, month, day, (PyObject *)Py_TYPE(self));
 }
 
 static Py_hash_t
@@ -4589,8 +4635,8 @@ datetime_time_replace_impl(PyDateTime_Time *self, int hour, int minute,
                            int fold)
 /*[clinic end generated code: output=0b89a44c299e4f80 input=9b6a35b1e704b0ca]*/
 {
-    return new_time_ex2(hour, minute, second, microsecond, tzinfo, fold,
-                        Py_TYPE(self));
+    return new_time_subclass_fold_ex(hour, minute, second, microsecond, tzinfo,
+                                     fold, (PyObject *)Py_TYPE(self));
 }
 
 static PyObject *
@@ -6039,8 +6085,9 @@ datetime_datetime_replace_impl(PyDateTime_DateTime *self, int year,
                                int fold)
 /*[clinic end generated code: output=00bc96536833fddb input=9b38253d56d9bcad]*/
 {
-    return new_datetime_ex2(year, month, day, hour, minute, second,
-                            microsecond, tzinfo, fold, Py_TYPE(self));
+    return new_datetime_subclass_fold_ex(year, month, day, hour, minute,
+                                         second, microsecond, tzinfo, fold,
+                                         (PyObject *)Py_TYPE(self));
 }
 
 static PyObject *