[PATCH v2 07/13] REST: Make 'Patch.state' editable

Stephen Finucane stephen at that.guru
Sun Nov 20 03:51:22 AEDT 2016

The patch state is one of the most useful fields to make editable via the
API, so allow it to be updated.

Signed-off-by: Stephen Finucane <stephen at that.guru>
Cc: Andy Doan <andy.doan at linaro.org>
 patchwork/api/__init__.py        |  5 +++++
 patchwork/api/patch.py           | 30 +++++++++++++++++++++++++-----
 patchwork/tests/test_rest_api.py |  6 ++++--
 3 files changed, 34 insertions(+), 7 deletions(-)

diff --git a/patchwork/api/__init__.py b/patchwork/api/__init__.py
index dbd8148..73a1dc7 100644
--- a/patchwork/api/__init__.py
+++ b/patchwork/api/__init__.py
@@ -23,6 +23,11 @@ from rest_framework import permissions
 from rest_framework.pagination import PageNumberPagination
 from rest_framework.response import Response
+from patchwork.models import State
+from django.db import DatabaseError
+
+# Slug form of every State name: lowercased, with whitespace-separated words
+# joined by '-'.
+# NOTE(review): this is evaluated once, when this module is imported, so
+# states created afterwards are not listed here.
+try:
+    STATE_CHOICES = ['-'.join(state.name.lower().split())
+                     for state in State.objects.all()]
+except DatabaseError:
+    # The state table may not exist yet (e.g. before migrations have been
+    # run); don't let importing the API fail because of that.
+    STATE_CHOICES = []
 class LinkHeaderPagination(PageNumberPagination):
     """Provide pagination based on rfc5988.
diff --git a/patchwork/api/patch.py b/patchwork/api/patch.py
index 737ada1..58fd843 100644
--- a/patchwork/api/patch.py
+++ b/patchwork/api/patch.py
@@ -20,25 +20,45 @@
 import email.parser
 from django.core.urlresolvers import reverse
-from rest_framework.serializers import HyperlinkedModelSerializer
+from rest_framework.exceptions import ValidationError
 from rest_framework.generics import ListAPIView
 from rest_framework.generics import RetrieveUpdateAPIView
+from rest_framework.serializers import ChoiceField
+from rest_framework.serializers import HyperlinkedModelSerializer
 from rest_framework.serializers import SerializerMethodField
 from patchwork.api import PatchworkPermission
+from patchwork.api import STATE_CHOICES
 from patchwork.models import Patch
+from patchwork.models import State
+class StateField(ChoiceField):
+    """Expose a patch's State by name, avoiding the need for a state endpoint.
+
+    States are rendered as slugs: the name lowercased, with its words joined
+    by '-' (e.g. 'Under Review' -> 'under-review'). On input, '-' is turned
+    back into a space and the State is looked up by case-insensitive name.
+    """
+
+    @staticmethod
+    def _slugify(name):
+        # Same formatting as used to build STATE_CHOICES.
+        return '-'.join(name.lower().split())
+
+    def __init__(self, *args, **kwargs):
+        # NOTE(review): STATE_CHOICES is a module-level snapshot; here it only
+        # supplies the field's declared choices, not input validation.
+        kwargs['choices'] = STATE_CHOICES
+        super(StateField, self).__init__(*args, **kwargs)
+
+    def to_internal_value(self, data):
+        """Return the State whose name matches the slug ``data``.
+
+        Raises ValidationError if ``data`` is not a string or does not name
+        an existing State.
+        """
+        try:
+            name = ' '.join(data.split('-'))
+        except (AttributeError, TypeError):
+            # Non-string input (e.g. an integer primary key) cannot be split;
+            # report it as a client error rather than a server error.
+            raise ValidationError(
+                'Incorrect type. Expected string value, received %s.' %
+                type(data).__name__)
+        try:
+            return State.objects.get(name__iexact=name)
+        except State.DoesNotExist:
+            # Query the database rather than reuse STATE_CHOICES so states
+            # added after import are included in the message.
+            choices = [self._slugify(state.name)
+                       for state in State.objects.all()]
+            raise ValidationError('Invalid state. Expected one of: %s' %
+                                  ', '.join(choices))
+
+    def to_representation(self, obj):
+        """Return the slug form of State ``obj``'s name."""
+        return self._slugify(obj.name)
 class PatchListSerializer(HyperlinkedModelSerializer):
     mbox = SerializerMethodField()
-    state = SerializerMethodField()
+    state = StateField()
     tags = SerializerMethodField()
     check = SerializerMethodField()
     checks = SerializerMethodField()
-    def get_state(self, instance):
-        return instance.state.name
     def get_mbox(self, instance):
         request = self.context.get('request')
         return request.build_absolute_uri(instance.get_mbox_url())
diff --git a/patchwork/tests/test_rest_api.py b/patchwork/tests/test_rest_api.py
index 469fd26..e8eb71f 100644
--- a/patchwork/tests/test_rest_api.py
+++ b/patchwork/tests/test_rest_api.py
@@ -31,6 +31,7 @@ from patchwork.tests.utils import create_maintainer
 from patchwork.tests.utils import create_patch
 from patchwork.tests.utils import create_person
 from patchwork.tests.utils import create_project
+from patchwork.tests.utils import create_state
 from patchwork.tests.utils import create_user
 if settings.ENABLE_REST_API:
@@ -368,11 +369,12 @@ class TestPatchAPI(APITestCase):
         # A maintainer can update
         project = create_project()
         patch = create_patch(project=project)
+        state = create_state()
         user = create_maintainer(project)
         resp = self.client.patch(self.api_url(patch.id),
-                                 {'state': 2})
+                                 {'state': state.name})
         self.assertEqual(status.HTTP_200_OK, resp.status_code)
         # A normal user can't
@@ -380,7 +382,7 @@ class TestPatchAPI(APITestCase):
         resp = self.client.patch(self.api_url(patch.id),
-                                 {'state': 2})
+                                 {'state': state.name})
         self.assertEqual(status.HTTP_403_FORBIDDEN, resp.status_code)
     def test_delete(self):

