From: Stephen Finucane Date: Thu, 15 Nov 2018 12:57:58 +0000 (+0100) Subject: Add REST API validation using OpenAPI schema X-Git-Tag: v2.2.0-rc1~179 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=08d1459a4a40;p=thirdparty%2Fpatchwork.git Add REST API validation using OpenAPI schema Add validation using the rather excellent 'openapi_core' library. The biggest issue we have to contend with is the fact that 'openapi_core' expects us to be able to provide a templated URL string for each request (e.g. '/api/patches/123/' would become '/api/patches//') and Django doesn't provide a way to do this [*]. We work around this by reverse-engineering some of the Django code to turn a URL to its matching regex, which we can then easily convert into a template string. It's kind of hacky and not at all portable but, crucially, it does work and has highlighted some nice bugs in the API that have already merged. Going forward, we can probably modify 'openapi_core' somewhat to remove the need for the templated URL string. If and when this happens, most of the funkier code here can happily go away. [*] Django 2.0+ [1] does actually provide a way to do template string-based URLs and in fact recommends them now, with regexes being reserved for more advanced corner cases. However, we don't want to drop support for the Django 1.11 yet as it is the most recent LTS release. [1] https://docs.djangoproject.com/en/2.1/ref/urls/#django.urls.path Signed-off-by: Stephen Finucane --- diff --git a/patchwork/tests/api/test_bundle.py b/patchwork/tests/api/test_bundle.py index e33c25ef..303c500c 100644 --- a/patchwork/tests/api/test_bundle.py +++ b/patchwork/tests/api/test_bundle.py @@ -16,15 +16,10 @@ from patchwork.tests.utils import create_user if settings.ENABLE_REST_API: from rest_framework import status - from rest_framework.test import APITestCase -else: - # stub out APITestCase - from django.test import TestCase - APITestCase = TestCase # noqa @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API') -class TestBundleAPI(APITestCase): +class TestBundleAPI(utils.APITestCase): fixtures = ['default_tags'] @staticmethod diff --git a/patchwork/tests/api/test_check.py b/patchwork/tests/api/test_check.py index e784ca92..0c10b945 100644 --- a/patchwork/tests/api/test_check.py +++ b/patchwork/tests/api/test_check.py @@ -18,15 +18,10 @@ from patchwork.tests.utils import create_user if settings.ENABLE_REST_API: from rest_framework import status - from rest_framework.test import APITestCase -else: - # stub out APITestCase - from django.test import TestCase - APITestCase = TestCase # noqa @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API') -class TestCheckAPI(APITestCase): +class TestCheckAPI(utils.APITestCase): fixtures = ['default_tags'] def api_url(self, item=None): diff --git a/patchwork/tests/api/test_comment.py b/patchwork/tests/api/test_comment.py index 56aaa200..f48bfce1 100644 --- a/patchwork/tests/api/test_comment.py +++ b/patchwork/tests/api/test_comment.py @@ -17,15 +17,10 @@ from patchwork.tests.utils import SAMPLE_CONTENT if settings.ENABLE_REST_API: from rest_framework import status - from rest_framework.test import APITestCase -else: - # stub out APITestCase - from django.test import TestCase - APITestCase = TestCase # noqa @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API') -class TestCoverComments(APITestCase): +class TestCoverComments(utils.APITestCase): @staticmethod def api_url(cover, version=None): kwargs = {} @@ -76,7 +71,7 @@ class TestCoverComments(APITestCase): @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API') -class TestPatchComments(APITestCase): +class TestPatchComments(utils.APITestCase): @staticmethod def api_url(patch, version=None): kwargs = {} diff --git a/patchwork/tests/api/test_cover.py b/patchwork/tests/api/test_cover.py index 8f96f387..0a0bf041 100644 --- a/patchwork/tests/api/test_cover.py +++ b/patchwork/tests/api/test_cover.py @@ -12,21 +12,14 @@ from django.urls import reverse from patchwork.tests.api import utils from patchwork.tests.utils import create_cover from patchwork.tests.utils import create_maintainer -from patchwork.tests.utils import create_person -from patchwork.tests.utils import create_project from patchwork.tests.utils import create_user if settings.ENABLE_REST_API: from rest_framework import status - from rest_framework.test import APITestCase -else: - # stub out APITestCase - from django.test import TestCase - APITestCase = TestCase # noqa @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API') -class TestCoverLetterAPI(APITestCase): +class TestCoverLetterAPI(utils.APITestCase): fixtures = ['default_tags'] @staticmethod diff --git a/patchwork/tests/api/test_event.py b/patchwork/tests/api/test_event.py index a2e89f53..8816538f 100644 --- a/patchwork/tests/api/test_event.py +++ b/patchwork/tests/api/test_event.py @@ -19,15 +19,10 @@ from patchwork.tests.utils import create_state if settings.ENABLE_REST_API: from rest_framework import status - from rest_framework.test import APITestCase -else: - # stub out APITestCase - from django.test import TestCase - APITestCase = TestCase # noqa @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API') -class TestEventAPI(APITestCase): +class TestEventAPI(utils.APITestCase): @staticmethod def api_url(version=None): diff --git a/patchwork/tests/api/test_patch.py b/patchwork/tests/api/test_patch.py index b501392c..82ae0184 100644 --- a/patchwork/tests/api/test_patch.py +++ b/patchwork/tests/api/test_patch.py @@ -21,15 +21,10 @@ from patchwork.tests.utils import create_user if settings.ENABLE_REST_API: from rest_framework import status - from rest_framework.test import APITestCase -else: - # stub out APITestCase - from django.test import TestCase - APITestCase = TestCase # noqa @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API') -class TestPatchAPI(APITestCase): +class TestPatchAPI(utils.APITestCase): fixtures = ['default_tags'] @staticmethod diff --git a/patchwork/tests/api/test_person.py b/patchwork/tests/api/test_person.py index aad37a7a..6bd3cb67 100644 --- a/patchwork/tests/api/test_person.py +++ b/patchwork/tests/api/test_person.py @@ -15,15 +15,10 @@ from patchwork.tests.utils import create_user if settings.ENABLE_REST_API: from rest_framework import status - from rest_framework.test import APITestCase -else: - # stub out APITestCase - from django.test import TestCase - APITestCase = TestCase # noqa @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API') -class TestPersonAPI(APITestCase): +class TestPersonAPI(utils.APITestCase): @staticmethod def api_url(item=None): diff --git a/patchwork/tests/api/test_project.py b/patchwork/tests/api/test_project.py index 77ac0b4a..5a767674 100644 --- a/patchwork/tests/api/test_project.py +++ b/patchwork/tests/api/test_project.py @@ -16,15 +16,10 @@ from patchwork.tests.utils import create_user if settings.ENABLE_REST_API: from rest_framework import status - from rest_framework.test import APITestCase -else: - # stub out APITestCase - from django.test import TestCase - APITestCase = TestCase # noqa @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API') -class TestProjectAPI(APITestCase): +class TestProjectAPI(utils.APITestCase): @staticmethod def api_url(item=None, version=None): diff --git a/patchwork/tests/api/test_series.py b/patchwork/tests/api/test_series.py index aecd8b04..13279120 100644 --- a/patchwork/tests/api/test_series.py +++ b/patchwork/tests/api/test_series.py @@ -19,15 +19,10 @@ from patchwork.tests.utils import create_user if settings.ENABLE_REST_API: from rest_framework import status - from rest_framework.test import APITestCase -else: - # stub out APITestCase - from django.test import TestCase - APITestCase = TestCase # noqa @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API') -class TestSeriesAPI(APITestCase): +class TestSeriesAPI(utils.APITestCase): fixtures = ['default_tags'] @staticmethod diff --git a/patchwork/tests/api/test_user.py b/patchwork/tests/api/test_user.py index c6114ee6..dfc4ddf1 100644 --- a/patchwork/tests/api/test_user.py +++ b/patchwork/tests/api/test_user.py @@ -14,15 +14,10 @@ from patchwork.tests.utils import create_user if settings.ENABLE_REST_API: from rest_framework import status - from rest_framework.test import APITestCase -else: - # stub out APITestCase - from django.test import TestCase - APITestCase = TestCase # noqa @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API') -class TestUserAPI(APITestCase): +class TestUserAPI(utils.APITestCase): @staticmethod def api_url(item=None): diff --git a/patchwork/tests/api/utils.py b/patchwork/tests/api/utils.py index 1097bb07..0c232d04 100644 --- a/patchwork/tests/api/utils.py +++ b/patchwork/tests/api/utils.py @@ -7,7 +7,19 @@ import functools import json import os -# docs/examples +from django.conf import settings +from django.test import testcases + +from patchwork.tests.api import validator + +if settings.ENABLE_REST_API: + from rest_framework.test import APIClient as BaseAPIClient + from rest_framework.test import APIRequestFactory +else: + from django.test import Client as BaseAPIClient + + +# docs/api/samples OUT_DIR = os.path.join( os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir, os.pardir, 'docs', 'api', 'samples') @@ -91,3 +103,55 @@ def store_samples(filename): return wrapper return inner + + +class APIClient(BaseAPIClient): + + def __init__(self, *args, **kwargs): + super(APIClient, self).__init__(*args, **kwargs) + self.factory = APIRequestFactory() + + def get(self, path, data=None, follow=False, **extra): + request = self.factory.get( + path, data=data, SERVER_NAME='example.com', **extra) + response = super(APIClient, self).get( + path, data=data, follow=follow, SERVER_NAME='example.com', **extra) + validator.validate_data(path, request, response) + return response + + def post(self, path, data=None, format=None, content_type=None, + follow=False, **extra): + request = self.factory.post( + path, data=data, format='json', content_type=content_type, + SERVER_NAME='example.com', **extra) + response = super(APIClient, self).post( + path, data=data, format='json', content_type=content_type, + follow=follow, SERVER_NAME='example.com', **extra) + validator.validate_data(path, request, response) + return response + + def put(self, path, data=None, format=None, content_type=None, + follow=False, **extra): + request = self.factory.put( + path, data=data, format='json', content_type=content_type, + SERVER_NAME='example.com', **extra) + response = super(APIClient, self).put( + path, data=data, format='json', content_type=content_type, + follow=follow, SERVER_NAME='example.com', **extra) + validator.validate_data(path, request, response) + return response + + def patch(self, path, data=None, format=None, content_type=None, + follow=False, **extra): + request = self.factory.patch( + path, data=data, format='json', content_type=content_type, + SERVER_NAME='example.com', **extra) + response = super(APIClient, self).patch( + path, data=data, format='json', content_type=content_type, + follow=follow, SERVER_NAME='example.com', **extra) + validator.validate_data(path, request, response) + return response + + +class APITestCase(testcases.TestCase): + client_class = APIClient diff --git a/patchwork/tests/api/validator.py b/patchwork/tests/api/validator.py new file mode 100644 index 00000000..3f138479 --- /dev/null +++ b/patchwork/tests/api/validator.py @@ -0,0 +1,317 @@ +# Patchwork - automated patch tracking system +# Copyright (C) 2018 Stephen Finucane +# +# SPDX-License-Identifier: GPL-2.0-or-later + +import os +import re + +import django +from django.urls import resolve +from django.urls.resolvers import get_resolver +from django.utils import six +import openapi_core +from openapi_core.schema.schemas.models import Format +from openapi_core.wrappers.base import BaseOpenAPIResponse +from openapi_core.wrappers.base import BaseOpenAPIRequest +from openapi_core.validation.request.validators import RequestValidator +from openapi_core.validation.response.validators import ResponseValidator +from openapi_core.schema.parameters.exceptions import OpenAPIParameterError +from openapi_core.schema.media_types.exceptions import OpenAPIMediaTypeError +from rest_framework import status +import yaml + +# docs/api/schemas +SCHEMAS_DIR = os.path.join( + os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir, + os.pardir, 'docs', 'api', 'schemas') + +HEADER_REGEXES = ( + re.compile(r'^HTTP_.+$'), re.compile(r'^CONTENT_TYPE$'), + re.compile(r'^CONTENT_LENGTH$')) + +_LOADED_SPECS = {} + + +class RegexValidator(object): + + def __init__(self, regex): + self.regex = re.compile(regex, re.IGNORECASE) + + def __call__(self, value): + if not isinstance(value, six.text_type): + return False + + if not value: + return True + + return self.regex.match(value) + + +CUSTOM_FORMATTERS = { + 'uri': Format(six.text_type, RegexValidator( + r'^(?:http|ftp)s?://' + r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' # noqa + r'localhost|' + r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' + r'(?::\d+)?' + r'(?:/?|[/?]\S+)$')), + 'iso8601': Format(six.text_type, RegexValidator( + r'^\d{4}-\d\d-\d\dT\d\d:\d\d:\d\d\.\d{6}$')), + 'email': Format(six.text_type, RegexValidator( + r'[^@]+@[^@]+\.[^@]+')), +} + + +def _extract_headers(request): + request_headers = {} + for header in request.META: + for regex in HEADER_REGEXES: + if regex.match(header): + request_headers[header] = request.META[header] + + return request_headers + + +def _resolve_django1x(path, resolver=None): + """Resolve a given path to its matching regex (Django 1.x). + + This is essentially a re-implementation of ``RegexURLResolver.resolve`` + that builds and returns the matched regex instead of the view itself. + + >>> _resolve_django1x('/api/1.0/patches/1/checks/') + "^api/(?:(?P(1.0|1.1))/)patches/(?P[^/]+)/checks/$" + """ + from django.urls.resolvers import RegexURLResolver # noqa + + resolver = resolver or get_resolver() + match = resolver.regex.search(path) + + if not match: + return + + if isinstance(resolver, RegexURLResolver): + sub_path = path[match.end():] + for sub_resolver in resolver.url_patterns: + sub_match = _resolve_django1x(sub_path, sub_resolver) + if not sub_match: + continue + + kwargs = dict(match.groupdict()) + kwargs.update(sub_match[2]) + args = sub_match[1] + if not kwargs: + args = match.groups() + args + + regex = resolver.regex.pattern + sub_match[0].lstrip('^') + + return regex, args, kwargs + else: # RegexURLPattern + kwargs = match.groupdict() + args = () if kwargs else match.groups() + return resolver.regex.pattern, args, kwargs + + +def _resolve_django2x(path, resolver=None): + """Resolve a given path to its matching regex (Django 2.x). + + This is essentially a re-implementation of ``URLResolver.resolve`` that + builds and returns the matched regex instead of the view itself. + + >>> _resolve_django2x('/api/1.0/patches/1/checks/') + "^api/(?:(?P(1.0|1.1))/)patches/(?P[^/]+)/checks/$" + """ + from django.urls.resolvers import URLResolver # noqa + from django.urls.resolvers import RegexPattern # noqa + + resolver = resolver or get_resolver() + match = resolver.pattern.match(path) + + # we dont handle any other type of pattern at the moment + assert isinstance(resolver.pattern, RegexPattern) + + if not match: + return + + if isinstance(resolver, URLResolver): + sub_path, args, kwargs = match + for sub_resolver in resolver.url_patterns: + sub_match = _resolve_django2x(sub_path, sub_resolver) + if not sub_match: + continue + + kwargs.update(sub_match[2]) + args += sub_match[1] + + regex = resolver.pattern._regex + sub_match[0].lstrip('^') + + return regex, args, kwargs + else: + _, args, kwargs = match + return resolver.pattern._regex, args, kwargs + + +if django.VERSION < (2, 0): + _resolve = _resolve_django1x +else: + _resolve = _resolve_django2x + + +def _resolve_path_to_kwargs(path): + """Convert a path to the kwargs used to resolve it. + + >>> resolve_path_to_kwargs('/api/1.0/patches/1/checks/') + {"patch_id": 1} + """ + # TODO(stephenfin): Handle definition by args + _, _, kwargs = _resolve(path) + + results = {} + for key, value in kwargs.items(): + if key == 'version': + continue + + if key == 'pk': + key = 'id' + + results[key] = value + + return results + + +def _resolve_path_to_template(path): + """Convert a path to a template string. + + >>> resolve_path_to_template('/api/1.0/patches/1/checks/') + "/api/{version}/patches/{patch_id}/checks/" + """ + regex, _, _ = _resolve(path) + regex = re.match(regex, path) + + result = '' + prev_index = 0 + for index, group in enumerate(regex.groups(), 1): + if not group: # group didn't match anything + continue + + result += path[prev_index:regex.start(index)] + prev_index = regex.end(index) + # groupindex keys by name, not index. Switch that. + for name, index_ in regex.re.groupindex.items(): + if index_ == (index): + # special-case version group + if name == 'version': + result += group + break + + if name == 'pk': + name = 'id' + + result += '{%s}' % name + break + + result += path[prev_index:] + + return result + + +def _load_spec(version): + global _LOADED_SPECS + + if _LOADED_SPECS.get(version): + return _LOADED_SPECS[version] + + spec_path = os.path.join(SCHEMAS_DIR, + 'v{}'.format(version) if version else 'latest', + 'patchwork.yaml') + + with open(spec_path, 'r') as fh: + data = yaml.load(fh) + + _LOADED_SPECS[version] = openapi_core.create_spec(data) + + return _LOADED_SPECS[version] + + +class DRFOpenAPIRequest(BaseOpenAPIRequest): + + def __init__(self, request): + self.request = request + + @property + def host_url(self): + return self.request.get_host() + + @property + def path(self): + return self.request.path + + @property + def method(self): + return self.request.method.lower() + + @property + def path_pattern(self): + return _resolve_path_to_template(self.request.path_info) + + @property + def parameters(self): + return { + 'path': _resolve_path_to_kwargs(self.request.path_info), + 'query': self.request.GET, + 'header': _extract_headers(self.request), + 'cookie': self.request.COOKIES, + } + + @property + def body(self): + return self.request.body.decode('utf-8') + + @property + def mimetype(self): + return self.request.content_type + + +class DRFOpenAPIResponse(BaseOpenAPIResponse): + + def __init__(self, response): + self.response = response + + @property + def data(self): + return self.response.content.decode('utf-8') + + @property + def status_code(self): + return self.response.status_code + + @property + def mimetype(self): + # TODO(stephenfin): Why isn't this populated? + return 'application/json' + + +def validate_data(path, request, response): + if response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED: + return + + spec = _load_spec(resolve(path).kwargs.get('version')) + request = DRFOpenAPIRequest(request) + response = DRFOpenAPIResponse(response) + + # request + validator = RequestValidator(spec, custom_formatters=CUSTOM_FORMATTERS) + result = validator.validate(request) + try: + result.raise_for_errors() + except OpenAPIMediaTypeError: + assert response.status_code == status.HTTP_400_BAD_REQUEST + except OpenAPIParameterError: + # TODO(stephenfin): In API v2.0, this should be an error. As things + # stand, we silently ignore these issues. + assert response.status_code == status.HTTP_200_OK + + # response + validator = ResponseValidator(spec, custom_formatters=CUSTOM_FORMATTERS) + result = validator.validate(request, response) + result.raise_for_errors() diff --git a/requirements-test.txt b/requirements-test.txt index 6c9bd88d..cfb8ce74 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -2,3 +2,4 @@ mysqlclient==1.3.13 psycopg2-binary==2.7.6 sqlparse==0.2.4 python-dateutil==2.7.5 +openapi-core==0.7.1