]> git.ipfire.org Git - thirdparty/patchwork.git/commitdiff
parser: close a TOCTTOU bug on Person creation
authorDaniel Axtens <dja@axtens.net>
Sat, 17 Feb 2018 01:54:51 +0000 (12:54 +1100)
committerDaniel Axtens <daniel.axtens@canonical.com>
Tue, 6 Mar 2018 14:23:49 +0000 (01:23 +1100)
find_author looks up a person by email, and if they do not exist,
creates a Person model, which may be saved later if the message
contains something valuable.

Multiple simultaneous processes can race here: both can do the SELECT,
find there is no Person, and create the model. One will succeed in
saving, the other will get an IntegrityError.

Reduce the window by making find_author into get_or_create_author, and
plumb that through. (Remove a test that specifically required find_author
to *not* create).

More importantly, cover the case where we lose the race, by using
get_or_create which handles the race case, catching the IntegrityError
internally and fetching the winning Person model.

Reviewed-by: Andrew Donnellan <andrew.donnellan@au1.ibm.com>
[dja: post review cleanup of now-unused import]
Signed-off-by: Daniel Axtens <dja@axtens.net>
patchwork/parser.py
patchwork/tests/test_parser.py

index ef6e9d55ca83c7256ffde6d2e082de26110e2bbe..463f05177ecdd28f99af13c5eaa0809694a88281 100644 (file)
@@ -255,7 +255,7 @@ def _find_series_by_references(project, mail):
             continue
 
 
-def _find_series_by_markers(project, mail):
+def _find_series_by_markers(project, mail, author):
     """Find a patch's series using series markers and sender.
 
     Identify suitable series for a patch using a combination of the
@@ -276,7 +276,6 @@ def _find_series_by_markers(project, mail):
     still won't help us if someone spams the mailing list with
     duplicate series but that's a tricky situation for anyone to parse.
     """
-    author = find_author(mail)
 
     subject = mail.get('Subject')
     name, prefixes = clean_subject(subject, [project.linkname])
@@ -296,7 +295,7 @@ def _find_series_by_markers(project, mail):
         return
 
 
-def find_series(project, mail):
+def find_series(project, mail, author):
     """Find a series, if any, for a given patch.
 
     Args:
@@ -311,10 +310,10 @@ def find_series(project, mail):
     if series:
         return series
 
-    return _find_series_by_markers(project, mail)
+    return _find_series_by_markers(project, mail, author)
 
 
-def find_author(mail):
+def get_or_create_author(mail):
     from_header = clean_header(mail.get('From'))
 
     if not from_header:
@@ -355,12 +354,17 @@ def find_author(mail):
     if name is not None:
         name = name.strip()[:255]
 
-    try:
-        person = Person.objects.get(email__iexact=email)
-        if name:  # use the latest provided name
-            person.name = name
-    except Person.DoesNotExist:
-        person = Person(name=name, email=email)
+    # this correctly handles the case where we lose the race to create
+    # the person and another process beats us to it. (If the record
+    # does not exist, g_o_c invokes _create_object_from_params which
+    # catches the IntegrityError and repeats the SELECT.)
+    person = Person.objects.get_or_create(email__iexact=email,
+                                          defaults={'name': name,
+                                                    'email': email})[0]
+
+    if name:  # use the latest provided name
+        person.name = name
+        person.save()
 
     return person
 
@@ -958,7 +962,6 @@ def parse_mail(mail, list_id=None):
         raise ValueError("Broken 'Message-Id' header")
     msgid = msgid[:255]
 
-    author = find_author(mail)
     subject = mail.get('Subject')
     name, prefixes = clean_subject(subject, [project.linkname])
     is_comment = subject_check(subject)
@@ -984,7 +987,7 @@ def parse_mail(mail, list_id=None):
 
     if not is_comment and (diff or pull_url):  # patches or pull requests
         # we delay the saving until we know we have a patch.
-        author.save()
+        author = get_or_create_author(mail)
 
         delegate = find_delegate_by_header(mail)
         if not delegate and diff:
@@ -995,7 +998,7 @@ def parse_mail(mail, list_id=None):
         # series to match against.
         series = None
         if n:
-            series = find_series(project, mail)
+            series = find_series(project, mail, author)
         else:
             x = n = 1
 
@@ -1072,7 +1075,7 @@ def parse_mail(mail, list_id=None):
                 is_cover_letter = True
 
         if is_cover_letter:
-            author.save()
+            author = get_or_create_author(mail)
 
             # we don't use 'find_series' here as a cover letter will
             # always be the first item in a thread, thus the references
@@ -1120,7 +1123,7 @@ def parse_mail(mail, list_id=None):
     if not submission:
         return
 
-    author.save()
+    author = get_or_create_author(mail)
 
     comment = Comment(
         submission=submission,
index 738fad7677ea87478964ba8b97b2c30f91729521..5ba06c0f3062f411ee918b01c1fffc74e8309928 100644 (file)
@@ -34,7 +34,7 @@ from patchwork.models import Patch
 from patchwork.models import Person
 from patchwork.models import State
 from patchwork.parser import clean_subject
-from patchwork.parser import find_author
+from patchwork.parser import get_or_create_author
 from patchwork.parser import find_patch_content as find_content
 from patchwork.parser import find_project
 from patchwork.parser import find_series
@@ -225,7 +225,7 @@ class SenderEncodingTest(TestCase):
 
     def _test_encoding(self, from_header, sender_name, sender_email):
         email = self._create_email(from_header)
-        person = find_author(email)
+        person = get_or_create_author(email)
         person.save()
 
         # ensure it was parsed correctly
@@ -241,7 +241,7 @@ class SenderEncodingTest(TestCase):
     def test_empty(self):
         email = self._create_email('')
         with self.assertRaises(ValueError):
-            find_author(email)
+            get_or_create_author(email)
 
     def test_ascii_encoding(self):
         from_header = 'example user <user@example.com>'
@@ -269,7 +269,7 @@ class SenderEncodingTest(TestCase):
 
 
 class SenderCorrelationTest(TestCase):
-    """Validate correct behavior of the find_author case.
+    """Validate correct behavior of the get_or_create_author case.
 
     Relies of checking the internal state of a Django model object.
 
@@ -284,25 +284,16 @@ class SenderCorrelationTest(TestCase):
                'test\n'
         return message_from_string(mail)
 
-    def test_non_existing_sender(self):
-        sender = 'Non-existing Sender <nonexisting@example.com>'
-        mail = self._create_email(sender)
-
-        # don't create the person - attempt to find immediately
-        person = find_author(mail)
-        self.assertEqual(person._state.adding, True)
-        self.assertEqual(person.id, None)
-
     def test_existing_sender(self):
         sender = 'Existing Sender <existing@example.com>'
         mail = self._create_email(sender)
 
         # create the person first
-        person_a = find_author(mail)
+        person_a = get_or_create_author(mail)
         person_a.save()
 
         # then attempt to parse email with the same 'From' line
-        person_b = find_author(mail)
+        person_b = get_or_create_author(mail)
         self.assertEqual(person_b._state.adding, False)
         self.assertEqual(person_b.id, person_a.id)
 
@@ -311,12 +302,12 @@ class SenderCorrelationTest(TestCase):
         mail = self._create_email(sender)
 
         # create the person first
-        person_a = find_author(mail)
+        person_a = get_or_create_author(mail)
         person_a.save()
 
         # then attempt to parse email with a new 'From' line
         mail = self._create_email('existing@example.com')
-        person_b = find_author(mail)
+        person_b = get_or_create_author(mail)
         self.assertEqual(person_b._state.adding, False)
         self.assertEqual(person_b.id, person_a.id)
 
@@ -324,11 +315,11 @@ class SenderCorrelationTest(TestCase):
         sender = 'Existing Sender <existing@example.com>'
         mail = self._create_email(sender)
 
-        person_a = find_author(mail)
+        person_a = get_or_create_author(mail)
         person_a.save()
 
         mail = self._create_email(sender.upper())
-        person_b = find_author(mail)
+        person_b = get_or_create_author(mail)
         self.assertEqual(person_b._state.adding, False)
         self.assertEqual(person_b.id, person_a.id)
 
@@ -361,7 +352,8 @@ class SeriesCorrelationTest(TestCase):
         email = self._create_email(msgid)
         project = create_project()
 
-        self.assertIsNone(find_series(project, email))
+        self.assertIsNone(find_series(project, email,
+                                      get_or_create_author(email)))
 
     def test_first_reply(self):
         msgid_a = make_msgid()
@@ -371,7 +363,8 @@ class SeriesCorrelationTest(TestCase):
         # assume msgid_a was already handled
         ref = create_series_reference(msgid=msgid_a)
 
-        series = find_series(ref.series.project, email)
+        series = find_series(ref.series.project, email,
+                             get_or_create_author(email))
         self.assertEqual(series, ref.series)
 
     def test_nested_series(self):
@@ -395,7 +388,7 @@ class SeriesCorrelationTest(TestCase):
         # ...and the "first patch" of this new series
         msgid = make_msgid()
         email = self._create_email(msgid, msgids)
-        series = find_series(project, email)
+        series = find_series(project, email, get_or_create_author(email))
 
         # this should link to the second series - not the first
         self.assertEqual(len(msgids), 4 + 1)  # old series + new cover