summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--patchwork/tests/api/test_bundle.py7
-rw-r--r--patchwork/tests/api/test_check.py7
-rw-r--r--patchwork/tests/api/test_comment.py9
-rw-r--r--patchwork/tests/api/test_cover.py9
-rw-r--r--patchwork/tests/api/test_event.py7
-rw-r--r--patchwork/tests/api/test_patch.py7
-rw-r--r--patchwork/tests/api/test_person.py7
-rw-r--r--patchwork/tests/api/test_project.py7
-rw-r--r--patchwork/tests/api/test_series.py7
-rw-r--r--patchwork/tests/api/test_user.py7
-rw-r--r--patchwork/tests/api/utils.py66
-rw-r--r--patchwork/tests/api/validator.py317
-rw-r--r--requirements-test.txt1
13 files changed, 394 insertions, 64 deletions
diff --git a/patchwork/tests/api/test_bundle.py b/patchwork/tests/api/test_bundle.py
index e33c25e..303c500 100644
--- a/patchwork/tests/api/test_bundle.py
+++ b/patchwork/tests/api/test_bundle.py
@@ -16,15 +16,10 @@ from patchwork.tests.utils import create_user
if settings.ENABLE_REST_API:
from rest_framework import status
- from rest_framework.test import APITestCase
-else:
- # stub out APITestCase
- from django.test import TestCase
- APITestCase = TestCase # noqa
@unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
-class TestBundleAPI(APITestCase):
+class TestBundleAPI(utils.APITestCase):
fixtures = ['default_tags']
@staticmethod
diff --git a/patchwork/tests/api/test_check.py b/patchwork/tests/api/test_check.py
index e784ca9..0c10b94 100644
--- a/patchwork/tests/api/test_check.py
+++ b/patchwork/tests/api/test_check.py
@@ -18,15 +18,10 @@ from patchwork.tests.utils import create_user
if settings.ENABLE_REST_API:
from rest_framework import status
- from rest_framework.test import APITestCase
-else:
- # stub out APITestCase
- from django.test import TestCase
- APITestCase = TestCase # noqa
@unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
-class TestCheckAPI(APITestCase):
+class TestCheckAPI(utils.APITestCase):
fixtures = ['default_tags']
def api_url(self, item=None):
diff --git a/patchwork/tests/api/test_comment.py b/patchwork/tests/api/test_comment.py
index 56aaa20..f48bfce 100644
--- a/patchwork/tests/api/test_comment.py
+++ b/patchwork/tests/api/test_comment.py
@@ -17,15 +17,10 @@ from patchwork.tests.utils import SAMPLE_CONTENT
if settings.ENABLE_REST_API:
from rest_framework import status
- from rest_framework.test import APITestCase
-else:
- # stub out APITestCase
- from django.test import TestCase
- APITestCase = TestCase # noqa
@unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
-class TestCoverComments(APITestCase):
+class TestCoverComments(utils.APITestCase):
@staticmethod
def api_url(cover, version=None):
kwargs = {}
@@ -76,7 +71,7 @@ class TestCoverComments(APITestCase):
@unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
-class TestPatchComments(APITestCase):
+class TestPatchComments(utils.APITestCase):
@staticmethod
def api_url(patch, version=None):
kwargs = {}
diff --git a/patchwork/tests/api/test_cover.py b/patchwork/tests/api/test_cover.py
index 8f96f38..0a0bf04 100644
--- a/patchwork/tests/api/test_cover.py
+++ b/patchwork/tests/api/test_cover.py
@@ -12,21 +12,14 @@ from django.urls import reverse
from patchwork.tests.api import utils
from patchwork.tests.utils import create_cover
from patchwork.tests.utils import create_maintainer
-from patchwork.tests.utils import create_person
-from patchwork.tests.utils import create_project
from patchwork.tests.utils import create_user
if settings.ENABLE_REST_API:
from rest_framework import status
- from rest_framework.test import APITestCase
-else:
- # stub out APITestCase
- from django.test import TestCase
- APITestCase = TestCase # noqa
@unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
-class TestCoverLetterAPI(APITestCase):
+class TestCoverLetterAPI(utils.APITestCase):
fixtures = ['default_tags']
@staticmethod
diff --git a/patchwork/tests/api/test_event.py b/patchwork/tests/api/test_event.py
index a2e89f5..8816538 100644
--- a/patchwork/tests/api/test_event.py
+++ b/patchwork/tests/api/test_event.py
@@ -19,15 +19,10 @@ from patchwork.tests.utils import create_state
if settings.ENABLE_REST_API:
from rest_framework import status
- from rest_framework.test import APITestCase
-else:
- # stub out APITestCase
- from django.test import TestCase
- APITestCase = TestCase # noqa
@unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
-class TestEventAPI(APITestCase):
+class TestEventAPI(utils.APITestCase):
@staticmethod
def api_url(version=None):
diff --git a/patchwork/tests/api/test_patch.py b/patchwork/tests/api/test_patch.py
index b501392..82ae018 100644
--- a/patchwork/tests/api/test_patch.py
+++ b/patchwork/tests/api/test_patch.py
@@ -21,15 +21,10 @@ from patchwork.tests.utils import create_user
if settings.ENABLE_REST_API:
from rest_framework import status
- from rest_framework.test import APITestCase
-else:
- # stub out APITestCase
- from django.test import TestCase
- APITestCase = TestCase # noqa
@unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
-class TestPatchAPI(APITestCase):
+class TestPatchAPI(utils.APITestCase):
fixtures = ['default_tags']
@staticmethod
diff --git a/patchwork/tests/api/test_person.py b/patchwork/tests/api/test_person.py
index aad37a7..6bd3cb6 100644
--- a/patchwork/tests/api/test_person.py
+++ b/patchwork/tests/api/test_person.py
@@ -15,15 +15,10 @@ from patchwork.tests.utils import create_user
if settings.ENABLE_REST_API:
from rest_framework import status
- from rest_framework.test import APITestCase
-else:
- # stub out APITestCase
- from django.test import TestCase
- APITestCase = TestCase # noqa
@unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
-class TestPersonAPI(APITestCase):
+class TestPersonAPI(utils.APITestCase):
@staticmethod
def api_url(item=None):
diff --git a/patchwork/tests/api/test_project.py b/patchwork/tests/api/test_project.py
index 77ac0b4..5a76767 100644
--- a/patchwork/tests/api/test_project.py
+++ b/patchwork/tests/api/test_project.py
@@ -16,15 +16,10 @@ from patchwork.tests.utils import create_user
if settings.ENABLE_REST_API:
from rest_framework import status
- from rest_framework.test import APITestCase
-else:
- # stub out APITestCase
- from django.test import TestCase
- APITestCase = TestCase # noqa
@unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
-class TestProjectAPI(APITestCase):
+class TestProjectAPI(utils.APITestCase):
@staticmethod
def api_url(item=None, version=None):
diff --git a/patchwork/tests/api/test_series.py b/patchwork/tests/api/test_series.py
index aecd8b0..1327912 100644
--- a/patchwork/tests/api/test_series.py
+++ b/patchwork/tests/api/test_series.py
@@ -19,15 +19,10 @@ from patchwork.tests.utils import create_user
if settings.ENABLE_REST_API:
from rest_framework import status
- from rest_framework.test import APITestCase
-else:
- # stub out APITestCase
- from django.test import TestCase
- APITestCase = TestCase # noqa
@unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
-class TestSeriesAPI(APITestCase):
+class TestSeriesAPI(utils.APITestCase):
fixtures = ['default_tags']
@staticmethod
diff --git a/patchwork/tests/api/test_user.py b/patchwork/tests/api/test_user.py
index c6114ee..dfc4ddf 100644
--- a/patchwork/tests/api/test_user.py
+++ b/patchwork/tests/api/test_user.py
@@ -14,15 +14,10 @@ from patchwork.tests.utils import create_user
if settings.ENABLE_REST_API:
from rest_framework import status
- from rest_framework.test import APITestCase
-else:
- # stub out APITestCase
- from django.test import TestCase
- APITestCase = TestCase # noqa
@unittest.skipUnless(settings.ENABLE_REST_API, 'requires ENABLE_REST_API')
-class TestUserAPI(APITestCase):
+class TestUserAPI(utils.APITestCase):
@staticmethod
def api_url(item=None):
diff --git a/patchwork/tests/api/utils.py b/patchwork/tests/api/utils.py
index 1097bb0..0c232d0 100644
--- a/patchwork/tests/api/utils.py
+++ b/patchwork/tests/api/utils.py
@@ -7,7 +7,19 @@ import functools
import json
import os
-# docs/examples
+from django.conf import settings
+from django.test import testcases
+
+from patchwork.tests.api import validator
+
+if settings.ENABLE_REST_API:
+ from rest_framework.test import APIClient as BaseAPIClient
+ from rest_framework.test import APIRequestFactory
+else:
+ from django.test import Client as BaseAPIClient
+
+
+# docs/api/samples
OUT_DIR = os.path.join(
os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir,
os.pardir, 'docs', 'api', 'samples')
@@ -91,3 +103,55 @@ def store_samples(filename):
return wrapper
return inner
+
+
+class APIClient(BaseAPIClient):
+
+ def __init__(self, *args, **kwargs):
+ super(APIClient, self).__init__(*args, **kwargs)
+ self.factory = APIRequestFactory()
+
+ def get(self, path, data=None, follow=False, **extra):
+ request = self.factory.get(
+ path, data=data, SERVER_NAME='example.com', **extra)
+ response = super(APIClient, self).get(
+ path, data=data, follow=follow, SERVER_NAME='example.com', **extra)
+ validator.validate_data(path, request, response)
+ return response
+
+ def post(self, path, data=None, format=None, content_type=None,
+ follow=False, **extra):
+ request = self.factory.post(
+ path, data=data, format='json', content_type=content_type,
+ SERVER_NAME='example.com', **extra)
+ response = super(APIClient, self).post(
+ path, data=data, format='json', content_type=content_type,
+ follow=follow, SERVER_NAME='example.com', **extra)
+ validator.validate_data(path, request, response)
+ return response
+
+ def put(self, path, data=None, format=None, content_type=None,
+ follow=False, **extra):
+ request = self.factory.put(
+ path, data=data, format='json', content_type=content_type,
+ SERVER_NAME='example.com', **extra)
+ response = super(APIClient, self).put(
+ path, data=data, format='json', content_type=content_type,
+ follow=follow, SERVER_NAME='example.com', **extra)
+ validator.validate_data(path, request, response)
+ return response
+
+ def patch(self, path, data=None, format=None, content_type=None,
+ follow=False, **extra):
+ request = self.factory.patch(
+ path, data=data, format='json', content_type=content_type,
+ SERVER_NAME='example.com', **extra)
+ response = super(APIClient, self).patch(
+ path, data=data, format='json', content_type=content_type,
+ follow=follow, SERVER_NAME='example.com', **extra)
+ validator.validate_data(path, request, response)
+ return response
+
+
+class APITestCase(testcases.TestCase):
+ client_class = APIClient
diff --git a/patchwork/tests/api/validator.py b/patchwork/tests/api/validator.py
new file mode 100644
index 0000000..3f13847
--- /dev/null
+++ b/patchwork/tests/api/validator.py
@@ -0,0 +1,317 @@
+# Patchwork - automated patch tracking system
+# Copyright (C) 2018 Stephen Finucane <stephen@that.guru>
+#
+# SPDX-License-Identifier: GPL-2.0-or-later
+
+import os
+import re
+
+import django
+from django.urls import resolve
+from django.urls.resolvers import get_resolver
+from django.utils import six
+import openapi_core
+from openapi_core.schema.schemas.models import Format
+from openapi_core.wrappers.base import BaseOpenAPIResponse
+from openapi_core.wrappers.base import BaseOpenAPIRequest
+from openapi_core.validation.request.validators import RequestValidator
+from openapi_core.validation.response.validators import ResponseValidator
+from openapi_core.schema.parameters.exceptions import OpenAPIParameterError
+from openapi_core.schema.media_types.exceptions import OpenAPIMediaTypeError
+from rest_framework import status
+import yaml
+
+# docs/api/schemas
+SCHEMAS_DIR = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir,
+ os.pardir, 'docs', 'api', 'schemas')
+
+HEADER_REGEXES = (
+ re.compile(r'^HTTP_.+$'), re.compile(r'^CONTENT_TYPE$'),
+ re.compile(r'^CONTENT_LENGTH$'))
+
+_LOADED_SPECS = {}
+
+
+class RegexValidator(object):
+
+ def __init__(self, regex):
+ self.regex = re.compile(regex, re.IGNORECASE)
+
+ def __call__(self, value):
+ if not isinstance(value, six.text_type):
+ return False
+
+ if not value:
+ return True
+
+ return self.regex.match(value)
+
+
+CUSTOM_FORMATTERS = {
+ 'uri': Format(six.text_type, RegexValidator(
+ r'^(?:http|ftp)s?://'
+ r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' # noqa
+ r'localhost|'
+ r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'
+ r'(?::\d+)?'
+ r'(?:/?|[/?]\S+)$')),
+ 'iso8601': Format(six.text_type, RegexValidator(
+ r'^\d{4}-\d\d-\d\dT\d\d:\d\d:\d\d\.\d{6}$')),
+ 'email': Format(six.text_type, RegexValidator(
+ r'[^@]+@[^@]+\.[^@]+')),
+}
+
+
+def _extract_headers(request):
+ request_headers = {}
+ for header in request.META:
+ for regex in HEADER_REGEXES:
+ if regex.match(header):
+ request_headers[header] = request.META[header]
+
+ return request_headers
+
+
+def _resolve_django1x(path, resolver=None):
+ """Resolve a given path to its matching regex (Django 1.x).
+
+ This is essentially a re-implementation of ``RegexURLResolver.resolve``
+ that builds and returns the matched regex instead of the view itself.
+
+ >>> _resolve_django1x('/api/1.0/patches/1/checks/')
+ "^api/(?:(?P<version>(1.0|1.1))/)patches/(?P<patch_id>[^/]+)/checks/$"
+ """
+ from django.urls.resolvers import RegexURLResolver # noqa
+
+ resolver = resolver or get_resolver()
+ match = resolver.regex.search(path)
+
+ if not match:
+ return
+
+ if isinstance(resolver, RegexURLResolver):
+ sub_path = path[match.end():]
+ for sub_resolver in resolver.url_patterns:
+ sub_match = _resolve_django1x(sub_path, sub_resolver)
+ if not sub_match:
+ continue
+
+ kwargs = dict(match.groupdict())
+ kwargs.update(sub_match[2])
+ args = sub_match[1]
+ if not kwargs:
+ args = match.groups() + args
+
+ regex = resolver.regex.pattern + sub_match[0].lstrip('^')
+
+ return regex, args, kwargs
+ else: # RegexURLPattern
+ kwargs = match.groupdict()
+ args = () if kwargs else match.groups()
+ return resolver.regex.pattern, args, kwargs
+
+
+def _resolve_django2x(path, resolver=None):
+ """Resolve a given path to its matching regex (Django 2.x).
+
+ This is essentially a re-implementation of ``URLResolver.resolve`` that
+ builds and returns the matched regex instead of the view itself.
+
+ >>> _resolve_django2x('/api/1.0/patches/1/checks/')
+ "^api/(?:(?P<version>(1.0|1.1))/)patches/(?P<patch_id>[^/]+)/checks/$"
+ """
+ from django.urls.resolvers import URLResolver # noqa
+ from django.urls.resolvers import RegexPattern # noqa
+
+ resolver = resolver or get_resolver()
+ match = resolver.pattern.match(path)
+
+ # we dont handle any other type of pattern at the moment
+ assert isinstance(resolver.pattern, RegexPattern)
+
+ if not match:
+ return
+
+ if isinstance(resolver, URLResolver):
+ sub_path, args, kwargs = match
+ for sub_resolver in resolver.url_patterns:
+ sub_match = _resolve_django2x(sub_path, sub_resolver)
+ if not sub_match:
+ continue
+
+ kwargs.update(sub_match[2])
+ args += sub_match[1]
+
+ regex = resolver.pattern._regex + sub_match[0].lstrip('^')
+
+ return regex, args, kwargs
+ else:
+ _, args, kwargs = match
+ return resolver.pattern._regex, args, kwargs
+
+
+if django.VERSION < (2, 0):
+ _resolve = _resolve_django1x
+else:
+ _resolve = _resolve_django2x
+
+
+def _resolve_path_to_kwargs(path):
+ """Convert a path to the kwargs used to resolve it.
+
+ >>> resolve_path_to_kwargs('/api/1.0/patches/1/checks/')
+ {"patch_id": 1}
+ """
+ # TODO(stephenfin): Handle definition by args
+ _, _, kwargs = _resolve(path)
+
+ results = {}
+ for key, value in kwargs.items():
+ if key == 'version':
+ continue
+
+ if key == 'pk':
+ key = 'id'
+
+ results[key] = value
+
+ return results
+
+
+def _resolve_path_to_template(path):
+ """Convert a path to a template string.
+
+ >>> resolve_path_to_template('/api/1.0/patches/1/checks/')
+ "/api/{version}/patches/{patch_id}/checks/"
+ """
+ regex, _, _ = _resolve(path)
+ regex = re.match(regex, path)
+
+ result = ''
+ prev_index = 0
+ for index, group in enumerate(regex.groups(), 1):
+ if not group: # group didn't match anything
+ continue
+
+ result += path[prev_index:regex.start(index)]
+ prev_index = regex.end(index)
+ # groupindex keys by name, not index. Switch that.
+ for name, index_ in regex.re.groupindex.items():
+ if index_ == (index):
+ # special-case version group
+ if name == 'version':
+ result += group
+ break
+
+ if name == 'pk':
+ name = 'id'
+
+ result += '{%s}' % name
+ break
+
+ result += path[prev_index:]
+
+ return result
+
+
+def _load_spec(version):
+ global _LOADED_SPECS
+
+ if _LOADED_SPECS.get(version):
+ return _LOADED_SPECS[version]
+
+ spec_path = os.path.join(SCHEMAS_DIR,
+ 'v{}'.format(version) if version else 'latest',
+ 'patchwork.yaml')
+
+ with open(spec_path, 'r') as fh:
+ data = yaml.load(fh)
+
+ _LOADED_SPECS[version] = openapi_core.create_spec(data)
+
+ return _LOADED_SPECS[version]
+
+
+class DRFOpenAPIRequest(BaseOpenAPIRequest):
+
+ def __init__(self, request):
+ self.request = request
+
+ @property
+ def host_url(self):
+ return self.request.get_host()
+
+ @property
+ def path(self):
+ return self.request.path
+
+ @property
+ def method(self):
+ return self.request.method.lower()
+
+ @property
+ def path_pattern(self):
+ return _resolve_path_to_template(self.request.path_info)
+
+ @property
+ def parameters(self):
+ return {
+ 'path': _resolve_path_to_kwargs(self.request.path_info),
+ 'query': self.request.GET,
+ 'header': _extract_headers(self.request),
+ 'cookie': self.request.COOKIES,
+ }
+
+ @property
+ def body(self):
+ return self.request.body.decode('utf-8')
+
+ @property
+ def mimetype(self):
+ return self.request.content_type
+
+
+class DRFOpenAPIResponse(BaseOpenAPIResponse):
+
+ def __init__(self, response):
+ self.response = response
+
+ @property
+ def data(self):
+ return self.response.content.decode('utf-8')
+
+ @property
+ def status_code(self):
+ return self.response.status_code
+
+ @property
+ def mimetype(self):
+ # TODO(stephenfin): Why isn't this populated?
+ return 'application/json'
+
+
+def validate_data(path, request, response):
+ if response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED:
+ return
+
+ spec = _load_spec(resolve(path).kwargs.get('version'))
+ request = DRFOpenAPIRequest(request)
+ response = DRFOpenAPIResponse(response)
+
+ # request
+ validator = RequestValidator(spec, custom_formatters=CUSTOM_FORMATTERS)
+ result = validator.validate(request)
+ try:
+ result.raise_for_errors()
+ except OpenAPIMediaTypeError:
+ assert response.status_code == status.HTTP_400_BAD_REQUEST
+ except OpenAPIParameterError:
+ # TODO(stephenfin): In API v2.0, this should be an error. As things
+ # stand, we silently ignore these issues.
+ assert response.status_code == status.HTTP_200_OK
+
+ # response
+ validator = ResponseValidator(spec, custom_formatters=CUSTOM_FORMATTERS)
+ result = validator.validate(request, response)
+ result.raise_for_errors()
diff --git a/requirements-test.txt b/requirements-test.txt
index 6c9bd88..cfb8ce7 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -2,3 +2,4 @@ mysqlclient==1.3.13
psycopg2-binary==2.7.6
sqlparse==0.2.4
python-dateutil==2.7.5
+openapi-core==0.7.1