]> git.ipfire.org Git - thirdparty/patchwork.git/commitdiff
REST: Don't iterate through models at class level
authorStephen Finucane <stephen@that.guru>
Tue, 7 Feb 2017 10:59:11 +0000 (10:59 +0000)
committerStephen Finucane <stephen@that.guru>
Wed, 8 Feb 2017 10:03:14 +0000 (10:03 +0000)
This causes two issues. Firstly, on fresh installs you see the following
error message:

  "Table 'patchwork.patchwork_state' doesn't exist"

Secondly, any new states created when the process is running will not be
reflected in the API until the server process is restarted.

Resolve this issue by moving the step into a method, thus ensuring it's
continuously refreshed. It doesn't seem possible to add tests to prevent
this regressing but some similarly useful tests are included to at least
validate the behavior of that field.

Signed-off-by: Stephen Finucane <stephen@that.guru>
Tested-By: Denis Laxalde <denis@laxalde.org>
Fixes: a2993505 ("REST: Make 'Patch.state' editable")
Closes-bug: #67
Closes-bug: #80

patchwork/api/base.py
patchwork/api/patch.py
patchwork/tests/test_rest_api.py

index 13a84322580a32f3b55e5c53323effdc6f5abeb7..dbd8148d033f3563162ae5c1865dd55f68051b5d 100644 (file)
@@ -23,11 +23,6 @@ from rest_framework import permissions
 from rest_framework.pagination import PageNumberPagination
 from rest_framework.response import Response
 
-from patchwork.models import State
-
-STATE_CHOICES = ['-'.join(x.name.lower().split(' '))
-                 for x in State.objects.all()]
-
 
 class LinkHeaderPagination(PageNumberPagination):
     """Provide pagination based on rfc5988.
index e8fb0ef3cd804e038c202e66f8e24101c8aadf4b..1a7be584d0beff51112a150f437ab20de950b81b 100644 (file)
 import email.parser
 
 from django.core.urlresolvers import reverse
-from rest_framework.exceptions import ValidationError
+from django.utils.translation import ugettext_lazy as _
 from rest_framework.generics import ListAPIView
 from rest_framework.generics import RetrieveUpdateAPIView
-from rest_framework.serializers import ChoiceField
+from rest_framework.relations import RelatedField
 from rest_framework.serializers import HyperlinkedModelSerializer
 from rest_framework.serializers import SerializerMethodField
 
 from patchwork.api.base import PatchworkPermission
-from patchwork.api.base import STATE_CHOICES
 from patchwork.api.filters import PatchFilter
 from patchwork.models import Patch
 from patchwork.models import State
 
 
-class StateField(ChoiceField):
-    """Avoid the need for a state endpoint."""
+def format_state_name(state):
+    return ' '.join(state.split('-'))
 
-    def __init__(self, *args, **kwargs):
-        kwargs['choices'] = STATE_CHOICES
-        super(StateField, self).__init__(*args, **kwargs)
+
+class StateField(RelatedField):
+    """Avoid the need for a state endpoint.
+
+    NOTE(stephenfin): This field will only function for State names consisting
+    of alphanumeric characters, underscores and single spaces. In Patchwork
+    2.0+, we should consider adding a slug field to the State object and make
+    use of the SlugRelatedField in DRF.
+    """
+    default_error_messages = {
+        'required': _('This field is required.'),
+        'invalid_choice': _('Invalid state {name}. Expected one of: '
+                            '{choices}.'),
+        'incorrect_type': _('Incorrect type. Expected string value, received '
+                            '{data_type}.'),
+    }
 
     def to_internal_value(self, data):
-        data = ' '.join(data.split('-'))
         try:
-            return State.objects.get(name__iexact=data)
+            data = format_state_name(data)
+            return self.get_queryset().get(name__iexact=data)
         except State.DoesNotExist:
-            raise ValidationError('Invalid state. Expected one of: %s ' %
-                                  ', '.join(STATE_CHOICES))
+            self.fail('invalid_choice', name=data, choices=', '.join([
+                format_state_name(x.name) for x in self.get_queryset()]))
+        except (TypeError, ValueError):
+            self.fail('incorrect_type', data_type=type(data).__name__)
 
     def to_representation(self, obj):
         return '-'.join(obj.name.lower().split())
 
+    def get_queryset(self):
+        return State.objects.all()
+
 
 class PatchListSerializer(HyperlinkedModelSerializer):
     mbox = SerializerMethodField()
index cc1fcef0ac9edc8919d6d4711ee7aed948df2a6c..b6e61440baece97c2fb219592fc5730ed1b090f3 100644 (file)
@@ -398,6 +398,20 @@ class TestPatchAPI(APITestCase):
         self.assertEqual(status.HTTP_200_OK, resp.status_code)
         self.assertEqual(Patch.objects.get(id=patch.id).state, state)
 
+    def test_update_invalid(self):
+        """Ensure we handle invalid Patch states."""
+        project = create_project()
+        state = create_state()
+        patch = create_patch(project=project, state=state)
+        user = create_maintainer(project)
+
+        # invalid state
+        self.client.force_authenticate(user=user)
+        resp = self.client.patch(self.api_url(patch.id), {'state': 'foobar'})
+        self.assertEqual(status.HTTP_400_BAD_REQUEST, resp.status_code)
+        self.assertContains(resp, 'Expected one of: %s.' % state.name,
+                            status_code=status.HTTP_400_BAD_REQUEST)
+
     def test_delete(self):
         """Ensure deletions are always rejected."""
         project = create_project()