]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-97982: Factorize PyUnicode_Count() and unicode_count() code (#98025)
authorNikita Sobolev <mail@sobolevn.me>
Wed, 12 Oct 2022 16:27:53 +0000 (19:27 +0300)
committerGitHub <noreply@github.com>
Wed, 12 Oct 2022 16:27:53 +0000 (18:27 +0200)
Add unicode_count_impl() to factorize PyUnicode_Count()
and unicode_count() code.

Lib/test/test_unicode.py
Objects/unicodeobject.c

index 5b816574f2c33541e26609a7e329d89ef45fca6a..15244cb949e3db29c6170536f0f35f2c109bdd0e 100644 (file)
@@ -241,6 +241,10 @@ class UnicodeTest(string_tests.CommonTest,
         self.checkequal(0, 'a' * 10, 'count', 'a\u0102')
         self.checkequal(0, 'a' * 10, 'count', 'a\U00100304')
         self.checkequal(0, '\u0102' * 10, 'count', '\u0102\U00100304')
+        # test subclass
+        class MyStr(str):
+            pass
+        self.checkequal(3, MyStr('aaa'), 'count', 'a')
 
     def test_find(self):
         string_tests.CommonTest.test_find(self)
@@ -3002,6 +3006,12 @@ class CAPITest(unittest.TestCase):
                 self.assertEqual(unicode_count(uni, ch, 0, len(uni)), 1)
                 self.assertEqual(unicode_count(st, ch, 0, len(st)), 0)
 
+        # subclasses should still work
+        class MyStr(str):
+            pass
+
+        self.assertEqual(unicode_count(MyStr('aab'), 'a', 0, 3), 2)
+
     # Test PyUnicode_FindChar()
     @support.cpython_only
     @unittest.skipIf(_testcapi is None, 'need _testcapi module')
index 51e660afba083af39f0158cc098027d86e45c599..5b737f1b596e3e1c44898289c46b22f6e9ada692 100644 (file)
@@ -8964,21 +8964,20 @@ _PyUnicode_InsertThousandsGrouping(
     return count;
 }
 
-
-Py_ssize_t
-PyUnicode_Count(PyObject *str,
-                PyObject *substr,
-                Py_ssize_t start,
-                Py_ssize_t end)
+static Py_ssize_t
+unicode_count_impl(PyObject *str,
+                   PyObject *substr,
+                   Py_ssize_t start,
+                   Py_ssize_t end)
 {
+    assert(PyUnicode_Check(str));
+    assert(PyUnicode_Check(substr));
+
     Py_ssize_t result;
     int kind1, kind2;
     const void *buf1 = NULL, *buf2 = NULL;
     Py_ssize_t len1, len2;
 
-    if (ensure_unicode(str) < 0 || ensure_unicode(substr) < 0)
-        return -1;
-
     kind1 = PyUnicode_KIND(str);
     kind2 = PyUnicode_KIND(substr);
     if (kind1 < kind2)
@@ -8998,6 +8997,7 @@ PyUnicode_Count(PyObject *str,
             goto onError;
     }
 
+    // We don't reuse `anylib_count` here because of the explicit casts.
     switch (kind1) {
     case PyUnicode_1BYTE_KIND:
         result = ucs1lib_count(
@@ -9033,6 +9033,18 @@ PyUnicode_Count(PyObject *str,
     return -1;
 }
 
+Py_ssize_t
+PyUnicode_Count(PyObject *str,
+                PyObject *substr,
+                Py_ssize_t start,
+                Py_ssize_t end)
+{
+    if (ensure_unicode(str) < 0 || ensure_unicode(substr) < 0)
+        return -1;
+
+    return unicode_count_impl(str, substr, start, end);
+}
+
 Py_ssize_t
 PyUnicode_Find(PyObject *str,
                PyObject *substr,
@@ -10848,62 +10860,16 @@ unicode_count(PyObject *self, PyObject *args)
     PyObject *substring = NULL;   /* initialize to fix a compiler warning */
     Py_ssize_t start = 0;
     Py_ssize_t end = PY_SSIZE_T_MAX;
-    PyObject *result;
-    int kind1, kind2;
-    const void *buf1, *buf2;
-    Py_ssize_t len1, len2, iresult;
+    Py_ssize_t result;
 
     if (!parse_args_finds_unicode("count", args, &substring, &start, &end))
         return NULL;
 
-    kind1 = PyUnicode_KIND(self);
-    kind2 = PyUnicode_KIND(substring);
-    if (kind1 < kind2)
-        return PyLong_FromLong(0);
-
-    len1 = PyUnicode_GET_LENGTH(self);
-    len2 = PyUnicode_GET_LENGTH(substring);
-    ADJUST_INDICES(start, end, len1);
-    if (end - start < len2)
-        return PyLong_FromLong(0);
-
-    buf1 = PyUnicode_DATA(self);
-    buf2 = PyUnicode_DATA(substring);
-    if (kind2 != kind1) {
-        buf2 = unicode_askind(kind2, buf2, len2, kind1);
-        if (!buf2)
-            return NULL;
-    }
-    switch (kind1) {
-    case PyUnicode_1BYTE_KIND:
-        iresult = ucs1lib_count(
-            ((const Py_UCS1*)buf1) + start, end - start,
-            buf2, len2, PY_SSIZE_T_MAX
-            );
-        break;
-    case PyUnicode_2BYTE_KIND:
-        iresult = ucs2lib_count(
-            ((const Py_UCS2*)buf1) + start, end - start,
-            buf2, len2, PY_SSIZE_T_MAX
-            );
-        break;
-    case PyUnicode_4BYTE_KIND:
-        iresult = ucs4lib_count(
-            ((const Py_UCS4*)buf1) + start, end - start,
-            buf2, len2, PY_SSIZE_T_MAX
-            );
-        break;
-    default:
-        Py_UNREACHABLE();
-    }
-
-    result = PyLong_FromSsize_t(iresult);
-
-    assert((kind2 == kind1) == (buf2 == PyUnicode_DATA(substring)));
-    if (kind2 != kind1)
-        PyMem_Free((void *)buf2);
+    result = unicode_count_impl(self, substring, start, end);
+    if (result == -1)
+        return NULL;
 
-    return result;
+    return PyLong_FromSsize_t(result);
 }
 
 /*[clinic input]