diff options
-rw-r--r-- | patchwork/tests/api/test_bundle.py | 7 | ||||
-rw-r--r-- | patchwork/tests/api/test_check.py | 7 | ||||
-rw-r--r-- | patchwork/tests/api/test_comment.py | 9 | ||||
-rw-r--r-- | patchwork/tests/api/test_cover.py | 9 | ||||
-rw-r--r-- | patchwork/tests/api/test_event.py | 7 | ||||
-rw-r--r-- | patchwork/tests/api/test_patch.py | 7 | ||||
-rw-r--r-- | patchwork/tests/api/test_person.py | 7 | ||||
-rw-r--r-- | patchwork/tests/api/test_project.py | 7 | ||||
-rw-r--r-- | patchwork/tests/api/test_series.py | 7 | ||||
-rw-r--r-- | patchwork/tests/api/test_user.py | 7 | ||||
-rw-r--r-- | patchwork/tests/api/utils.py | 66 | ||||
-rw-r--r-- | patchwork/tests/api/validator.py | 317 | ||||
-rw-r--r-- | requirements-test.txt | 1 |
13 files changed, 394 insertions, 64 deletions
diff --git a/patchwork/tests/api/test_bundle.py b/patchwork/tests/api/test_bundle.py index e33c25e..303c500 100644 --- a/patchwork/tests/api/test_bundle.py +++ b/patchwork/tests/api/test_bundle.py @@ -16,15 +16,10 @@ from patchwork.tests.utils import create_user if settings.ENABLE_REST_API: from rest_framework import status - from rest_framework.test import APITestCase -else: - # stub out APITestCase - from django.test import TestCase - APITestCase = TestCase # noqa @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API') -class TestBundleAPI(APITestCase): +class TestBundleAPI(utils.APITestCase): fixtures = ['default_tags'] @staticmethod diff --git a/patchwork/tests/api/test_check.py b/patchwork/tests/api/test_check.py index e784ca9..0c10b94 100644 --- a/patchwork/tests/api/test_check.py +++ b/patchwork/tests/api/test_check.py @@ -18,15 +18,10 @@ from patchwork.tests.utils import create_user if settings.ENABLE_REST_API: from rest_framework import status - from rest_framework.test import APITestCase -else: - # stub out APITestCase - from django.test import TestCase - APITestCase = TestCase # noqa @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API') -class TestCheckAPI(APITestCase): +class TestCheckAPI(utils.APITestCase): fixtures = ['default_tags'] def api_url(self, item=None): diff --git a/patchwork/tests/api/test_comment.py b/patchwork/tests/api/test_comment.py index 56aaa20..f48bfce 100644 --- a/patchwork/tests/api/test_comment.py +++ b/patchwork/tests/api/test_comment.py @@ -17,15 +17,10 @@ from patchwork.tests.utils import SAMPLE_CONTENT if settings.ENABLE_REST_API: from rest_framework import status - from rest_framework.test import APITestCase -else: - # stub out APITestCase - from django.test import TestCase - APITestCase = TestCase # noqa @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API') -class TestCoverComments(APITestCase): +class TestCoverComments(utils.APITestCase): @staticmethod def api_url(cover, version=None): kwargs = {} @@ -76,7 +71,7 @@ class TestCoverComments(APITestCase): @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API') -class TestPatchComments(APITestCase): +class TestPatchComments(utils.APITestCase): @staticmethod def api_url(patch, version=None): kwargs = {} diff --git a/patchwork/tests/api/test_cover.py b/patchwork/tests/api/test_cover.py index 8f96f38..0a0bf04 100644 --- a/patchwork/tests/api/test_cover.py +++ b/patchwork/tests/api/test_cover.py @@ -12,21 +12,14 @@ from django.urls import reverse from patchwork.tests.api import utils from patchwork.tests.utils import create_cover from patchwork.tests.utils import create_maintainer -from patchwork.tests.utils import create_person -from patchwork.tests.utils import create_project from patchwork.tests.utils import create_user if settings.ENABLE_REST_API: from rest_framework import status - from rest_framework.test import APITestCase -else: - # stub out APITestCase - from django.test import TestCase - APITestCase = TestCase # noqa @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API') -class TestCoverLetterAPI(APITestCase): +class TestCoverLetterAPI(utils.APITestCase): fixtures = ['default_tags'] @staticmethod diff --git a/patchwork/tests/api/test_event.py b/patchwork/tests/api/test_event.py index a2e89f5..8816538 100644 --- a/patchwork/tests/api/test_event.py +++ b/patchwork/tests/api/test_event.py @@ -19,15 +19,10 @@ from patchwork.tests.utils import create_state if settings.ENABLE_REST_API: from rest_framework import status - from rest_framework.test import APITestCase -else: - # stub out APITestCase - from django.test import TestCase - APITestCase = TestCase # noqa @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API') -class TestEventAPI(APITestCase): +class TestEventAPI(utils.APITestCase): @staticmethod def api_url(version=None): diff --git a/patchwork/tests/api/test_patch.py b/patchwork/tests/api/test_patch.py index b501392..82ae018 100644 --- a/patchwork/tests/api/test_patch.py +++ b/patchwork/tests/api/test_patch.py @@ -21,15 +21,10 @@ from patchwork.tests.utils import create_user if settings.ENABLE_REST_API: from rest_framework import status - from rest_framework.test import APITestCase -else: - # stub out APITestCase - from django.test import TestCase - APITestCase = TestCase # noqa @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API') -class TestPatchAPI(APITestCase): +class TestPatchAPI(utils.APITestCase): fixtures = ['default_tags'] @staticmethod diff --git a/patchwork/tests/api/test_person.py b/patchwork/tests/api/test_person.py index aad37a7..6bd3cb6 100644 --- a/patchwork/tests/api/test_person.py +++ b/patchwork/tests/api/test_person.py @@ -15,15 +15,10 @@ from patchwork.tests.utils import create_user if settings.ENABLE_REST_API: from rest_framework import status - from rest_framework.test import APITestCase -else: - # stub out APITestCase - from django.test import TestCase - APITestCase = TestCase # noqa @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API') -class TestPersonAPI(APITestCase): +class TestPersonAPI(utils.APITestCase): @staticmethod def api_url(item=None): diff --git a/patchwork/tests/api/test_project.py b/patchwork/tests/api/test_project.py index 77ac0b4..5a76767 100644 --- a/patchwork/tests/api/test_project.py +++ b/patchwork/tests/api/test_project.py @@ -16,15 +16,10 @@ from patchwork.tests.utils import create_user if settings.ENABLE_REST_API: from rest_framework import status - from rest_framework.test import APITestCase -else: - # stub out APITestCase - from django.test import TestCase - APITestCase = TestCase # noqa @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API') -class TestProjectAPI(APITestCase): +class TestProjectAPI(utils.APITestCase): @staticmethod def api_url(item=None, version=None): diff --git a/patchwork/tests/api/test_series.py b/patchwork/tests/api/test_series.py index aecd8b0..1327912 100644 --- a/patchwork/tests/api/test_series.py +++ b/patchwork/tests/api/test_series.py @@ -19,15 +19,10 @@ from patchwork.tests.utils import create_user if settings.ENABLE_REST_API: from rest_framework import status - from rest_framework.test import APITestCase -else: - # stub out APITestCase - from django.test import TestCase - APITestCase = TestCase # noqa @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API') -class TestSeriesAPI(APITestCase): +class TestSeriesAPI(utils.APITestCase): fixtures = ['default_tags'] @staticmethod diff --git a/patchwork/tests/api/test_user.py b/patchwork/tests/api/test_user.py index c6114ee..dfc4ddf 100644 --- a/patchwork/tests/api/test_user.py +++ b/patchwork/tests/api/test_user.py @@ -14,15 +14,10 @@ from patchwork.tests.utils import create_user if settings.ENABLE_REST_API: from rest_framework import status - from rest_framework.test import APITestCase -else: - # stub out APITestCase - from django.test import TestCase - APITestCase = TestCase # noqa @unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API') -class TestUserAPI(APITestCase): +class TestUserAPI(utils.APITestCase): @staticmethod def api_url(item=None): diff --git a/patchwork/tests/api/utils.py b/patchwork/tests/api/utils.py index 1097bb0..0c232d0 100644 --- a/patchwork/tests/api/utils.py +++ b/patchwork/tests/api/utils.py @@ -7,7 +7,19 @@ import functools import json import os -# docs/examples +from django.conf import settings +from django.test import testcases + +from patchwork.tests.api import validator + +if settings.ENABLE_REST_API: + from rest_framework.test import APIClient as BaseAPIClient + from rest_framework.test import APIRequestFactory +else: + from django.test import Client as BaseAPIClient + + +# docs/api/samples OUT_DIR = os.path.join( os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir, os.pardir, 'docs', 'api', 'samples') @@ -91,3 +103,55 @@ def store_samples(filename): return wrapper return inner + + +class APIClient(BaseAPIClient): + + def __init__(self, *args, **kwargs): + super(APIClient, self).__init__(*args, **kwargs) + self.factory = APIRequestFactory() + + def get(self, path, data=None, follow=False, **extra): + request = self.factory.get( + path, data=data, SERVER_NAME='example.com', **extra) + response = super(APIClient, self).get( + path, data=data, follow=follow, SERVER_NAME='example.com', **extra) + validator.validate_data(path, request, response) + return response + + def post(self, path, data=None, format=None, content_type=None, + follow=False, **extra): + request = self.factory.post( + path, data=data, format='json', content_type=content_type, + SERVER_NAME='example.com', **extra) + response = super(APIClient, self).post( + path, data=data, format='json', content_type=content_type, + follow=follow, SERVER_NAME='example.com', **extra) + validator.validate_data(path, request, response) + return response + + def put(self, path, data=None, format=None, content_type=None, + follow=False, **extra): + request = self.factory.put( + path, data=data, format='json', content_type=content_type, + SERVER_NAME='example.com', **extra) + response = super(APIClient, self).put( + path, data=data, format='json', content_type=content_type, + follow=follow, SERVER_NAME='example.com', **extra) + validator.validate_data(path, request, response) + return response + + def patch(self, path, data=None, format=None, content_type=None, + follow=False, **extra): + request = self.factory.patch( + path, data=data, format='json', content_type=content_type, + SERVER_NAME='example.com', **extra) + response = super(APIClient, self).patch( + path, data=data, format='json', content_type=content_type, + follow=follow, SERVER_NAME='example.com', **extra) + validator.validate_data(path, request, response) + return response + + +class APITestCase(testcases.TestCase): + client_class = APIClient diff --git a/patchwork/tests/api/validator.py b/patchwork/tests/api/validator.py new file mode 100644 index 0000000..3f13847 --- /dev/null +++ b/patchwork/tests/api/validator.py @@ -0,0 +1,317 @@ +# Patchwork - automated patch tracking system +# Copyright (C) 2018 Stephen Finucane <stephen@that.guru> +# +# SPDX-License-Identifier: GPL-2.0-or-later + +import os +import re + +import django +from django.urls import resolve +from django.urls.resolvers import get_resolver +from django.utils import six +import openapi_core +from openapi_core.schema.schemas.models import Format +from openapi_core.wrappers.base import BaseOpenAPIResponse +from openapi_core.wrappers.base import BaseOpenAPIRequest +from openapi_core.validation.request.validators import RequestValidator +from openapi_core.validation.response.validators import ResponseValidator +from openapi_core.schema.parameters.exceptions import OpenAPIParameterError +from openapi_core.schema.media_types.exceptions import OpenAPIMediaTypeError +from rest_framework import status +import yaml + +# docs/api/schemas +SCHEMAS_DIR = os.path.join( + os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir, + os.pardir, 'docs', 'api', 'schemas') + +HEADER_REGEXES = ( + re.compile(r'^HTTP_.+$'), re.compile(r'^CONTENT_TYPE$'), + re.compile(r'^CONTENT_LENGTH$')) + +_LOADED_SPECS = {} + + +class RegexValidator(object): + + def __init__(self, regex): + self.regex = re.compile(regex, re.IGNORECASE) + + def __call__(self, value): + if not isinstance(value, six.text_type): + return False + + if not value: + return True + + return self.regex.match(value) + + +CUSTOM_FORMATTERS = { + 'uri': Format(six.text_type, RegexValidator( + r'^(?:http|ftp)s?://' + r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' # noqa + r'localhost|' + r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' + r'(?::\d+)?' + r'(?:/?|[/?]\S+)$')), + 'iso8601': Format(six.text_type, RegexValidator( + r'^\d{4}-\d\d-\d\dT\d\d:\d\d:\d\d\.\d{6}$')), + 'email': Format(six.text_type, RegexValidator( + r'[^@]+@[^@]+\.[^@]+')), +} + + +def _extract_headers(request): + request_headers = {} + for header in request.META: + for regex in HEADER_REGEXES: + if regex.match(header): + request_headers[header] = request.META[header] + + return request_headers + + +def _resolve_django1x(path, resolver=None): + """Resolve a given path to its matching regex (Django 1.x). + + This is essentially a re-implementation of ``RegexURLResolver.resolve`` + that builds and returns the matched regex instead of the view itself. + + >>> _resolve_django1x('/api/1.0/patches/1/checks/') + "^api/(?:(?P<version>(1.0|1.1))/)patches/(?P<patch_id>[^/]+)/checks/$" + """ + from django.urls.resolvers import RegexURLResolver # noqa + + resolver = resolver or get_resolver() + match = resolver.regex.search(path) + + if not match: + return + + if isinstance(resolver, RegexURLResolver): + sub_path = path[match.end():] + for sub_resolver in resolver.url_patterns: + sub_match = _resolve_django1x(sub_path, sub_resolver) + if not sub_match: + continue + + kwargs = dict(match.groupdict()) + kwargs.update(sub_match[2]) + args = sub_match[1] + if not kwargs: + args = match.groups() + args + + regex = resolver.regex.pattern + sub_match[0].lstrip('^') + + return regex, args, kwargs + else: # RegexURLPattern + kwargs = match.groupdict() + args = () if kwargs else match.groups() + return resolver.regex.pattern, args, kwargs + + +def _resolve_django2x(path, resolver=None): + """Resolve a given path to its matching regex (Django 2.x). + + This is essentially a re-implementation of ``URLResolver.resolve`` that + builds and returns the matched regex instead of the view itself. + + >>> _resolve_django2x('/api/1.0/patches/1/checks/') + "^api/(?:(?P<version>(1.0|1.1))/)patches/(?P<patch_id>[^/]+)/checks/$" + """ + from django.urls.resolvers import URLResolver # noqa + from django.urls.resolvers import RegexPattern # noqa + + resolver = resolver or get_resolver() + match = resolver.pattern.match(path) + + # we dont handle any other type of pattern at the moment + assert isinstance(resolver.pattern, RegexPattern) + + if not match: + return + + if isinstance(resolver, URLResolver): + sub_path, args, kwargs = match + for sub_resolver in resolver.url_patterns: + sub_match = _resolve_django2x(sub_path, sub_resolver) + if not sub_match: + continue + + kwargs.update(sub_match[2]) + args += sub_match[1] + + regex = resolver.pattern._regex + sub_match[0].lstrip('^') + + return regex, args, kwargs + else: + _, args, kwargs = match + return resolver.pattern._regex, args, kwargs + + +if django.VERSION < (2, 0): + _resolve = _resolve_django1x +else: + _resolve = _resolve_django2x + + +def _resolve_path_to_kwargs(path): + """Convert a path to the kwargs used to resolve it. + + >>> resolve_path_to_kwargs('/api/1.0/patches/1/checks/') + {"patch_id": 1} + """ + # TODO(stephenfin): Handle definition by args + _, _, kwargs = _resolve(path) + + results = {} + for key, value in kwargs.items(): + if key == 'version': + continue + + if key == 'pk': + key = 'id' + + results[key] = value + + return results + + +def _resolve_path_to_template(path): + """Convert a path to a template string. + + >>> resolve_path_to_template('/api/1.0/patches/1/checks/') + "/api/{version}/patches/{patch_id}/checks/" + """ + regex, _, _ = _resolve(path) + regex = re.match(regex, path) + + result = '' + prev_index = 0 + for index, group in enumerate(regex.groups(), 1): + if not group: # group didn't match anything + continue + + result += path[prev_index:regex.start(index)] + prev_index = regex.end(index) + # groupindex keys by name, not index. Switch that. + for name, index_ in regex.re.groupindex.items(): + if index_ == (index): + # special-case version group + if name == 'version': + result += group + break + + if name == 'pk': + name = 'id' + + result += '{%s}' % name + break + + result += path[prev_index:] + + return result + + +def _load_spec(version): + global _LOADED_SPECS + + if _LOADED_SPECS.get(version): + return _LOADED_SPECS[version] + + spec_path = os.path.join(SCHEMAS_DIR, + 'v{}'.format(version) if version else 'latest', + 'patchwork.yaml') + + with open(spec_path, 'r') as fh: + data = yaml.load(fh) + + _LOADED_SPECS[version] = openapi_core.create_spec(data) + + return _LOADED_SPECS[version] + + +class DRFOpenAPIRequest(BaseOpenAPIRequest): + + def __init__(self, request): + self.request = request + + @property + def host_url(self): + return self.request.get_host() + + @property + def path(self): + return self.request.path + + @property + def method(self): + return self.request.method.lower() + + @property + def path_pattern(self): + return _resolve_path_to_template(self.request.path_info) + + @property + def parameters(self): + return { + 'path': _resolve_path_to_kwargs(self.request.path_info), + 'query': self.request.GET, + 'header': _extract_headers(self.request), + 'cookie': self.request.COOKIES, + } + + @property + def body(self): + return self.request.body.decode('utf-8') + + @property + def mimetype(self): + return self.request.content_type + + +class DRFOpenAPIResponse(BaseOpenAPIResponse): + + def __init__(self, response): + self.response = response + + @property + def data(self): + return self.response.content.decode('utf-8') + + @property + def status_code(self): + return self.response.status_code + + @property + def mimetype(self): + # TODO(stephenfin): Why isn't this populated? + return 'application/json' + + +def validate_data(path, request, response): + if response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED: + return + + spec = _load_spec(resolve(path).kwargs.get('version')) + request = DRFOpenAPIRequest(request) + response = DRFOpenAPIResponse(response) + + # request + validator = RequestValidator(spec, custom_formatters=CUSTOM_FORMATTERS) + result = validator.validate(request) + try: + result.raise_for_errors() + except OpenAPIMediaTypeError: + assert response.status_code == status.HTTP_400_BAD_REQUEST + except OpenAPIParameterError: + # TODO(stephenfin): In API v2.0, this should be an error. As things + # stand, we silently ignore these issues. + assert response.status_code == status.HTTP_200_OK + + # response + validator = ResponseValidator(spec, custom_formatters=CUSTOM_FORMATTERS) + result = validator.validate(request, response) + result.raise_for_errors() diff --git a/requirements-test.txt b/requirements-test.txt index 6c9bd88..cfb8ce7 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -2,3 +2,4 @@ mysqlclient==1.3.13 psycopg2-binary==2.7.6 sqlparse==0.2.4 python-dateutil==2.7.5 +openapi-core==0.7.1 |