]> git.ipfire.org Git - thirdparty/patchwork.git/commitdiff
Add REST API validation using OpenAPI schema
authorStephen Finucane <stephen@that.guru>
Thu, 15 Nov 2018 12:57:58 +0000 (13:57 +0100)
committerStephen Finucane <stephen@that.guru>
Sat, 22 Dec 2018 17:19:07 +0000 (17:19 +0000)
Add validation using the rather excellent 'openapi_core' library. The
biggest issue we have to contend with is the fact that 'openapi_core'
expects us to be able to provide a templated URL string for each request
(e.g. '/api/patches/123/' would become '/api/patches/<id>/') and Django
doesn't provide a way to do this [*]. We work around this by
reverse-engineering some of the Django code to turn a URL to its
matching regex, which we can then easily convert into a template string.
It's kind of hacky and not at all portable but, crucially, it does work
and has highlighted some nice bugs in the API that have already merged.

Going forward, we can probably modify 'openapi_core' somewhat to remove
the need for the templated URL string. If and when this happens, most of
the funkier code here can happily go away.

[*] Django 2.0+ [1] does actually provide a way to do template
string-based URLs and in fact recommends them now, with regexes being
reserved for more advanced corner cases. However, we don't want to drop
support for the Django 1.11 yet as it is the most recent LTS release.

[1] https://docs.djangoproject.com/en/2.1/ref/urls/#django.urls.path

Signed-off-by: Stephen Finucane <stephen@that.guru>
13 files changed:
patchwork/tests/api/test_bundle.py
patchwork/tests/api/test_check.py
patchwork/tests/api/test_comment.py
patchwork/tests/api/test_cover.py
patchwork/tests/api/test_event.py
patchwork/tests/api/test_patch.py
patchwork/tests/api/test_person.py
patchwork/tests/api/test_project.py
patchwork/tests/api/test_series.py
patchwork/tests/api/test_user.py
patchwork/tests/api/utils.py
patchwork/tests/api/validator.py [new file with mode: 0644]
requirements-test.txt

index e33c25ef06a8abde24ce17dfb7fcecd9ff44cf27..303c500c395bc8821611b8626d9742e26a8d5fe3 100644 (file)
@@ -16,15 +16,10 @@ from patchwork.tests.utils import create_user
 
 if settings.ENABLE_REST_API:
     from rest_framework import status
-    from rest_framework.test import APITestCase
-else:
-    # stub out APITestCase
-    from django.test import TestCase
-    APITestCase = TestCase  # noqa
 
 
 @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
-class TestBundleAPI(APITestCase):
+class TestBundleAPI(utils.APITestCase):
     fixtures = ['default_tags']
 
     @staticmethod
index e784ca9297e9bac5acf2c63ba2e03b6641f6ff0b..0c10b94553d30afb3c097aa82b08bbb12af44e5e 100644 (file)
@@ -18,15 +18,10 @@ from patchwork.tests.utils import create_user
 
 if settings.ENABLE_REST_API:
     from rest_framework import status
-    from rest_framework.test import APITestCase
-else:
-    # stub out APITestCase
-    from django.test import TestCase
-    APITestCase = TestCase  # noqa
 
 
 @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
-class TestCheckAPI(APITestCase):
+class TestCheckAPI(utils.APITestCase):
     fixtures = ['default_tags']
 
     def api_url(self, item=None):
index 56aaa2002e3a5a5680d259d1830187e46b1154f2..f48bfce1abf06ac948d51dd0956bc8c9b9069368 100644 (file)
@@ -17,15 +17,10 @@ from patchwork.tests.utils import SAMPLE_CONTENT
 
 if settings.ENABLE_REST_API:
     from rest_framework import status
-    from rest_framework.test import APITestCase
-else:
-    # stub out APITestCase
-    from django.test import TestCase
-    APITestCase = TestCase  # noqa
 
 
 @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
-class TestCoverComments(APITestCase):
+class TestCoverComments(utils.APITestCase):
     @staticmethod
     def api_url(cover, version=None):
         kwargs = {}
@@ -76,7 +71,7 @@ class TestCoverComments(APITestCase):
 
 
 @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
-class TestPatchComments(APITestCase):
+class TestPatchComments(utils.APITestCase):
     @staticmethod
     def api_url(patch, version=None):
         kwargs = {}
index 8f96f3878364dffd2b77e17f4a6fdcb457c9be50..0a0bf041abc54c77312a7e3b9a95ebb7041da278 100644 (file)
@@ -12,21 +12,14 @@ from django.urls import reverse
 from patchwork.tests.api import utils
 from patchwork.tests.utils import create_cover
 from patchwork.tests.utils import create_maintainer
-from patchwork.tests.utils import create_person
-from patchwork.tests.utils import create_project
 from patchwork.tests.utils import create_user
 
 if settings.ENABLE_REST_API:
     from rest_framework import status
-    from rest_framework.test import APITestCase
-else:
-    # stub out APITestCase
-    from django.test import TestCase
-    APITestCase = TestCase  # noqa
 
 
 @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
-class TestCoverLetterAPI(APITestCase):
+class TestCoverLetterAPI(utils.APITestCase):
     fixtures = ['default_tags']
 
     @staticmethod
index a2e89f53e0396a0369fc08208cbd4b914fdf0fa4..8816538fa071df0915b68611861751e3cdee3795 100644 (file)
@@ -19,15 +19,10 @@ from patchwork.tests.utils import create_state
 
 if settings.ENABLE_REST_API:
     from rest_framework import status
-    from rest_framework.test import APITestCase
-else:
-    # stub out APITestCase
-    from django.test import TestCase
-    APITestCase = TestCase  # noqa
 
 
 @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
-class TestEventAPI(APITestCase):
+class TestEventAPI(utils.APITestCase):
 
     @staticmethod
     def api_url(version=None):
index b501392c6aef3d6360fd09b20053618da81c14f2..82ae01841e7b673a2cd906d8969615818272bf7f 100644 (file)
@@ -21,15 +21,10 @@ from patchwork.tests.utils import create_user
 
 if settings.ENABLE_REST_API:
     from rest_framework import status
-    from rest_framework.test import APITestCase
-else:
-    # stub out APITestCase
-    from django.test import TestCase
-    APITestCase = TestCase  # noqa
 
 
 @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
-class TestPatchAPI(APITestCase):
+class TestPatchAPI(utils.APITestCase):
     fixtures = ['default_tags']
 
     @staticmethod
index aad37a7aab4f07418c7fb57b84b690f146fe2f6c..6bd3cb67cdb1f2c21688d70489538cb0067f1f19 100644 (file)
@@ -15,15 +15,10 @@ from patchwork.tests.utils import create_user
 
 if settings.ENABLE_REST_API:
     from rest_framework import status
-    from rest_framework.test import APITestCase
-else:
-    # stub out APITestCase
-    from django.test import TestCase
-    APITestCase = TestCase  # noqa
 
 
 @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
-class TestPersonAPI(APITestCase):
+class TestPersonAPI(utils.APITestCase):
 
     @staticmethod
     def api_url(item=None):
index 77ac0b4aadfaf732a2eeb60e899a9e138785b3b3..5a7676740affe779b5c952d7501bbda827acb830 100644 (file)
@@ -16,15 +16,10 @@ from patchwork.tests.utils import create_user
 
 if settings.ENABLE_REST_API:
     from rest_framework import status
-    from rest_framework.test import APITestCase
-else:
-    # stub out APITestCase
-    from django.test import TestCase
-    APITestCase = TestCase  # noqa
 
 
 @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
-class TestProjectAPI(APITestCase):
+class TestProjectAPI(utils.APITestCase):
 
     @staticmethod
     def api_url(item=None, version=None):
index aecd8b04fb2b9a53b4b90b487d781eedf4cf61f7..132791202f9490500917120b8200f1a43083484a 100644 (file)
@@ -19,15 +19,10 @@ from patchwork.tests.utils import create_user
 
 if settings.ENABLE_REST_API:
     from rest_framework import status
-    from rest_framework.test import APITestCase
-else:
-    # stub out APITestCase
-    from django.test import TestCase
-    APITestCase = TestCase  # noqa
 
 
 @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
-class TestSeriesAPI(APITestCase):
+class TestSeriesAPI(utils.APITestCase):
     fixtures = ['default_tags']
 
     @staticmethod
index c6114ee635bd1ed0120293a66a0f4d85a00b8544..dfc4ddf150599517b3c4899dcdf5e136751365d1 100644 (file)
@@ -14,15 +14,10 @@ from patchwork.tests.utils import create_user
 
 if settings.ENABLE_REST_API:
     from rest_framework import status
-    from rest_framework.test import APITestCase
-else:
-    # stub out APITestCase
-    from django.test import TestCase
-    APITestCase = TestCase  # noqa
 
 
 @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
-class TestUserAPI(APITestCase):
+class TestUserAPI(utils.APITestCase):
 
     @staticmethod
     def api_url(item=None):
index 1097bb07a5715e28f855ce5adeb4a080f5303c3d..0c232d04b5da21359ebb63e6a5de699b7ffe7c8d 100644 (file)
@@ -7,7 +7,19 @@ import functools
 import json
 import os
 
-# docs/examples
+from django.conf import settings
+from django.test import testcases
+
+from patchwork.tests.api import validator
+
+if settings.ENABLE_REST_API:
+    from rest_framework.test import APIClient as BaseAPIClient
+    from rest_framework.test import APIRequestFactory
+else:
+    from django.test import Client as BaseAPIClient
+
+
+# docs/api/samples
 OUT_DIR = os.path.join(
     os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir,
     os.pardir, 'docs', 'api', 'samples')
@@ -91,3 +103,55 @@ def store_samples(filename):
         return wrapper
 
     return inner
+
+
+class APIClient(BaseAPIClient):
+
+    def __init__(self, *args, **kwargs):
+        super(APIClient, self).__init__(*args, **kwargs)
+        self.factory = APIRequestFactory()
+
+    def get(self, path, data=None, follow=False, **extra):
+        request = self.factory.get(
+            path, data=data, SERVER_NAME='example.com', **extra)
+        response = super(APIClient, self).get(
+            path, data=data, follow=follow, SERVER_NAME='example.com', **extra)
+        validator.validate_data(path, request, response)
+        return response
+
+    def post(self, path, data=None, format=None, content_type=None,
+             follow=False, **extra):
+        request = self.factory.post(
+            path, data=data, format='json', content_type=content_type,
+            SERVER_NAME='example.com', **extra)
+        response = super(APIClient, self).post(
+            path, data=data, format='json', content_type=content_type,
+            follow=follow, SERVER_NAME='example.com', **extra)
+        validator.validate_data(path, request, response)
+        return response
+
+    def put(self, path, data=None, format=None, content_type=None,
+            follow=False, **extra):
+        request = self.factory.put(
+            path, data=data, format='json', content_type=content_type,
+            SERVER_NAME='example.com', **extra)
+        response = super(APIClient, self).put(
+            path, data=data, format='json', content_type=content_type,
+            follow=follow, SERVER_NAME='example.com', **extra)
+        validator.validate_data(path, request, response)
+        return response
+
+    def patch(self, path, data=None, format=None, content_type=None,
+              follow=False, **extra):
+        request = self.factory.patch(
+            path, data=data, format='json', content_type=content_type,
+            SERVER_NAME='example.com', **extra)
+        response = super(APIClient, self).patch(
+            path, data=data, format='json', content_type=content_type,
+            follow=follow, SERVER_NAME='example.com', **extra)
+        validator.validate_data(path, request, response)
+        return response
+
+
+class APITestCase(testcases.TestCase):
+    client_class = APIClient
diff --git a/patchwork/tests/api/validator.py b/patchwork/tests/api/validator.py
new file mode 100644 (file)
index 0000000..3f13847
--- /dev/null
@@ -0,0 +1,317 @@
+# Patchwork - automated patch tracking system
+# Copyright (C) 2018 Stephen Finucane <stephen@that.guru>
+#
+# SPDX-License-Identifier: GPL-2.0-or-later
+
+import os
+import re
+
+import django
+from django.urls import resolve
+from django.urls.resolvers import get_resolver
+from django.utils import six
+import openapi_core
+from openapi_core.schema.schemas.models import Format
+from openapi_core.wrappers.base import BaseOpenAPIResponse
+from openapi_core.wrappers.base import BaseOpenAPIRequest
+from openapi_core.validation.request.validators import RequestValidator
+from openapi_core.validation.response.validators import ResponseValidator
+from openapi_core.schema.parameters.exceptions import OpenAPIParameterError
+from openapi_core.schema.media_types.exceptions import OpenAPIMediaTypeError
+from rest_framework import status
+import yaml
+
+# docs/api/schemas
+SCHEMAS_DIR = os.path.join(
+    os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir,
+    os.pardir, 'docs', 'api', 'schemas')
+
+HEADER_REGEXES = (
+    re.compile(r'^HTTP_.+$'), re.compile(r'^CONTENT_TYPE$'),
+    re.compile(r'^CONTENT_LENGTH$'))
+
+_LOADED_SPECS = {}
+
+
+class RegexValidator(object):
+
+    def __init__(self, regex):
+        self.regex = re.compile(regex, re.IGNORECASE)
+
+    def __call__(self, value):
+        if not isinstance(value, six.text_type):
+            return False
+
+        if not value:
+            return True
+
+        return self.regex.match(value)
+
+
+CUSTOM_FORMATTERS = {
+    'uri': Format(six.text_type, RegexValidator(
+        r'^(?:http|ftp)s?://'
+        r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|'  # noqa
+        r'localhost|'
+        r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'
+        r'(?::\d+)?'
+        r'(?:/?|[/?]\S+)$')),
+    'iso8601': Format(six.text_type, RegexValidator(
+        r'^\d{4}-\d\d-\d\dT\d\d:\d\d:\d\d\.\d{6}$')),
+    'email': Format(six.text_type, RegexValidator(
+        r'[^@]+@[^@]+\.[^@]+')),
+}
+
+
+def _extract_headers(request):
+    request_headers = {}
+    for header in request.META:
+        for regex in HEADER_REGEXES:
+            if regex.match(header):
+                request_headers[header] = request.META[header]
+
+    return request_headers
+
+
+def _resolve_django1x(path, resolver=None):
+    """Resolve a given path to its matching regex (Django 1.x).
+
+    This is essentially a re-implementation of ``RegexURLResolver.resolve``
+    that builds and returns the matched regex instead of the view itself.
+
+    >>> _resolve_django1x('/api/1.0/patches/1/checks/')
+    "^api/(?:(?P<version>(1.0|1.1))/)patches/(?P<patch_id>[^/]+)/checks/$"
+    """
+    from django.urls.resolvers import RegexURLResolver  # noqa
+
+    resolver = resolver or get_resolver()
+    match = resolver.regex.search(path)
+
+    if not match:
+        return
+
+    if isinstance(resolver, RegexURLResolver):
+        sub_path = path[match.end():]
+        for sub_resolver in resolver.url_patterns:
+            sub_match = _resolve_django1x(sub_path, sub_resolver)
+            if not sub_match:
+                continue
+
+            kwargs = dict(match.groupdict())
+            kwargs.update(sub_match[2])
+            args = sub_match[1]
+            if not kwargs:
+                args = match.groups() + args
+
+            regex = resolver.regex.pattern + sub_match[0].lstrip('^')
+
+            return regex, args, kwargs
+    else:  # RegexURLPattern
+        kwargs = match.groupdict()
+        args = () if kwargs else match.groups()
+        return resolver.regex.pattern, args, kwargs
+
+
+def _resolve_django2x(path, resolver=None):
+    """Resolve a given path to its matching regex (Django 2.x).
+
+    This is essentially a re-implementation of ``URLResolver.resolve`` that
+    builds and returns the matched regex instead of the view itself.
+
+    >>> _resolve_django2x('/api/1.0/patches/1/checks/')
+    "^api/(?:(?P<version>(1.0|1.1))/)patches/(?P<patch_id>[^/]+)/checks/$"
+    """
+    from django.urls.resolvers import URLResolver  # noqa
+    from django.urls.resolvers import RegexPattern  # noqa
+
+    resolver = resolver or get_resolver()
+    match = resolver.pattern.match(path)
+
+    # we dont handle any other type of pattern at the moment
+    assert isinstance(resolver.pattern, RegexPattern)
+
+    if not match:
+        return
+
+    if isinstance(resolver, URLResolver):
+        sub_path, args, kwargs = match
+        for sub_resolver in resolver.url_patterns:
+            sub_match = _resolve_django2x(sub_path, sub_resolver)
+            if not sub_match:
+                continue
+
+            kwargs.update(sub_match[2])
+            args += sub_match[1]
+
+            regex = resolver.pattern._regex + sub_match[0].lstrip('^')
+
+            return regex, args, kwargs
+    else:
+        _, args, kwargs = match
+        return resolver.pattern._regex, args, kwargs
+
+
+if django.VERSION < (2, 0):
+    _resolve = _resolve_django1x
+else:
+    _resolve = _resolve_django2x
+
+
+def _resolve_path_to_kwargs(path):
+    """Convert a path to the kwargs used to resolve it.
+
+    >>> resolve_path_to_kwargs('/api/1.0/patches/1/checks/')
+    {"patch_id": 1}
+    """
+    # TODO(stephenfin): Handle definition by args
+    _, _, kwargs = _resolve(path)
+
+    results = {}
+    for key, value in kwargs.items():
+        if key == 'version':
+            continue
+
+        if key == 'pk':
+            key = 'id'
+
+        results[key] = value
+
+    return results
+
+
+def _resolve_path_to_template(path):
+    """Convert a path to a template string.
+
+    >>> resolve_path_to_template('/api/1.0/patches/1/checks/')
+    "/api/{version}/patches/{patch_id}/checks/"
+    """
+    regex, _, _ = _resolve(path)
+    regex = re.match(regex, path)
+
+    result = ''
+    prev_index = 0
+    for index, group in enumerate(regex.groups(), 1):
+        if not group:  # group didn't match anything
+            continue
+
+        result += path[prev_index:regex.start(index)]
+        prev_index = regex.end(index)
+        # groupindex keys by name, not index. Switch that.
+        for name, index_ in regex.re.groupindex.items():
+            if index_ == (index):
+                # special-case version group
+                if name == 'version':
+                    result += group
+                    break
+
+                if name == 'pk':
+                    name = 'id'
+
+                result += '{%s}' % name
+                break
+
+    result += path[prev_index:]
+
+    return result
+
+
+def _load_spec(version):
+    global _LOADED_SPECS
+
+    if _LOADED_SPECS.get(version):
+        return _LOADED_SPECS[version]
+
+    spec_path = os.path.join(SCHEMAS_DIR,
+                             'v{}'.format(version) if version else 'latest',
+                             'patchwork.yaml')
+
+    with open(spec_path, 'r') as fh:
+        data = yaml.load(fh)
+
+    _LOADED_SPECS[version] = openapi_core.create_spec(data)
+
+    return _LOADED_SPECS[version]
+
+
+class DRFOpenAPIRequest(BaseOpenAPIRequest):
+
+    def __init__(self, request):
+        self.request = request
+
+    @property
+    def host_url(self):
+        return self.request.get_host()
+
+    @property
+    def path(self):
+        return self.request.path
+
+    @property
+    def method(self):
+        return self.request.method.lower()
+
+    @property
+    def path_pattern(self):
+        return _resolve_path_to_template(self.request.path_info)
+
+    @property
+    def parameters(self):
+        return {
+            'path': _resolve_path_to_kwargs(self.request.path_info),
+            'query': self.request.GET,
+            'header': _extract_headers(self.request),
+            'cookie': self.request.COOKIES,
+        }
+
+    @property
+    def body(self):
+        return self.request.body.decode('utf-8')
+
+    @property
+    def mimetype(self):
+        return self.request.content_type
+
+
+class DRFOpenAPIResponse(BaseOpenAPIResponse):
+
+    def __init__(self, response):
+        self.response = response
+
+    @property
+    def data(self):
+        return self.response.content.decode('utf-8')
+
+    @property
+    def status_code(self):
+        return self.response.status_code
+
+    @property
+    def mimetype(self):
+        # TODO(stephenfin): Why isn't this populated?
+        return 'application/json'
+
+
+def validate_data(path, request, response):
+    if response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED:
+        return
+
+    spec = _load_spec(resolve(path).kwargs.get('version'))
+    request = DRFOpenAPIRequest(request)
+    response = DRFOpenAPIResponse(response)
+
+    # request
+    validator = RequestValidator(spec, custom_formatters=CUSTOM_FORMATTERS)
+    result = validator.validate(request)
+    try:
+        result.raise_for_errors()
+    except OpenAPIMediaTypeError:
+        assert response.status_code == status.HTTP_400_BAD_REQUEST
+    except OpenAPIParameterError:
+        # TODO(stephenfin): In API v2.0, this should be an error. As things
+        # stand, we silently ignore these issues.
+        assert response.status_code == status.HTTP_200_OK
+
+    # response
+    validator = ResponseValidator(spec, custom_formatters=CUSTOM_FORMATTERS)
+    result = validator.validate(request, response)
+    result.raise_for_errors()
index 6c9bd88d516efac9bc4b31a3d088d6491d22a484..cfb8ce749b54be176130ae8e19d4bfc0b47e4e7c 100644 (file)
@@ -2,3 +2,4 @@ mysqlclient==1.3.13
 psycopg2-binary==2.7.6
 sqlparse==0.2.4
 python-dateutil==2.7.5
+openapi-core==0.7.1