summaryrefslogtreecommitdiff
path: root/patchwork/tests/api/validator.py
blob: 8ae8918260e6c5d35d95c8a956a7b767a0e4028f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Patchwork - automated patch tracking system
# Copyright (C) 2018 Stephen Finucane <stephen@that.guru>
#
# SPDX-License-Identifier: GPL-2.0-or-later

import os
import re

from django.urls import resolve
import openapi_core
from openapi_core.contrib.django import DjangoOpenAPIResponseFactory
from openapi_core.contrib.django import DjangoOpenAPIRequestFactory
from openapi_core.schema.schemas.models import Format
from openapi_core.validation.request.validators import RequestValidator
from openapi_core.validation.response.validators import ResponseValidator
from openapi_core.schema.parameters.exceptions import OpenAPIParameterError
from openapi_core.schema.media_types.exceptions import OpenAPIMediaTypeError
from openapi_core.templating import util
from rest_framework import status
import yaml

# Location of the OpenAPI schema files: 'docs/api/schemas', reached by walking
# three directory levels up from the directory containing this module. The
# path is intentionally left un-normalised (it still contains the 'pardir'
# components).
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
SCHEMAS_DIR = os.path.join(
    _THIS_DIR, os.pardir, os.pardir, os.pardir, 'docs', 'api', 'schemas',
)

# Cache of already-built specs, keyed by the API version they were loaded for.
_LOADED_SPECS = {}


# HACK! Workaround for https://github.com/p1c2u/openapi-core/issues/226
def search(path_pattern, full_url_pattern):
    """Match a templated path against the tail of a full URL pattern.

    The parser's expression gets a trailing ``$`` so that the path template
    has to match right up to the end of ``full_url_pattern``. A match is
    additionally discarded if any captured named argument contains a ``/``.

    :param path_pattern: The path template handed to ``util.Parser``.
    :param full_url_pattern: The string to search within.
    :returns: The parser's search result, or ``None`` if there was no
        acceptable match.
    """
    parser = util.Parser(path_pattern)
    # Anchor the generated expression at the end of the string.
    parser._expression += '$'

    match = parser.search(full_url_pattern)
    if not match:
        return None

    # A captured parameter must not contain a path separator.
    for captured in match.named.values():
        if '/' in captured:
            return None

    return match


# Swap the library's own implementation for the patched one above.
util.search = search


class RegexValidator(object):
    """Callable that checks whether a string matches a regular expression.

    The expression is compiled once, case-insensitively, and applied with
    ``re.match`` semantics, i.e. it is anchored at the start of the value
    only; anchor the pattern with ``$`` if a full match is required.
    """

    def __init__(self, regex):
        """Compile ``regex`` (a pattern string) with ``re.IGNORECASE``."""
        self.regex = re.compile(regex, re.IGNORECASE)

    def __call__(self, value):
        """Validate ``value`` against the compiled expression.

        :param value: The candidate value.
        :returns: ``False`` for any non-``str`` value, ``True`` for the empty
            string (empty values are always accepted), otherwise ``True`` if
            the expression matches at the start of ``value`` and ``False`` if
            it does not.
        """
        if not isinstance(value, str):
            return False

        # Empty strings are accepted without consulting the expression.
        if not value:
            return True

        # Always hand back a real bool rather than leaking the match object
        # (or None), so every code path has the same return type.
        return self.regex.match(value) is not None


# http(s)/ftp(s) scheme, followed by a host (domain name, 'localhost' or a
# dotted quad of 1-3 digit groups), an optional port and an optional path or
# query.
_URI_PATTERN = (
    r'^(?:http|ftp)s?://'
    r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|'  # noqa
    r'localhost|'
    r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'
    r'(?::\d+)?'
    r'(?:/?|[/?]\S+)$'
)

# Date and time separated by 'T', with exactly six fractional-second digits
# and no timezone suffix.
_ISO8601_PATTERN = r'^\d{4}-\d\d-\d\dT\d\d:\d\d:\d\d\.\d{6}$'

# Loose 'something@something.something' check. The pattern carries no
# anchors of its own.
_EMAIL_PATTERN = r'[^@]+@[^@]+\.[^@]+'

# Extra string formats handed to the request/response validators, keyed by
# the format name.
CUSTOM_FORMATTERS = {
    'uri': Format(str, RegexValidator(_URI_PATTERN)),
    'iso8601': Format(str, RegexValidator(_ISO8601_PATTERN)),
    'email': Format(str, RegexValidator(_EMAIL_PATTERN)),
}


def _load_spec(version):
    """Return the OpenAPI spec for ``version``, building it on first use.

    Specs are cached in ``_LOADED_SPECS`` so each schema file is parsed at
    most once per version.

    :param version: The API version; the schema is read from the
        ``v<version>`` sub-directory of ``SCHEMAS_DIR``. A falsy value (e.g.
        ``None``) selects the ``latest`` sub-directory instead.
    :returns: The object produced by ``openapi_core.create_spec`` for the
        parsed ``patchwork.yaml``.
    :raises OSError: If the schema file cannot be opened.
    """
    # No 'global' statement is needed: the cache dict is only mutated, never
    # rebound.
    cached = _LOADED_SPECS.get(version)
    if cached:
        return cached

    spec_path = os.path.join(SCHEMAS_DIR,
                             'v{}'.format(version) if version else 'latest',
                             'patchwork.yaml')

    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, so the schema parses the same way everywhere.
    with open(spec_path, 'r', encoding='utf-8') as fh:
        data = yaml.load(fh, Loader=yaml.SafeLoader)

    _LOADED_SPECS[version] = openapi_core.create_spec(data)

    return _LOADED_SPECS[version]


def validate_data(path, request, response, validate_request,
                  validate_response):
    """Check a request/response pair against the OpenAPI schema.

    Nothing is validated when the response is a 405 (Method Not Allowed).
    Otherwise the spec is selected using the ``version`` keyword argument
    resolved from ``path``.

    :param path: URL path; resolved to find the API version in use.
    :param request: The request, converted with
        ``DjangoOpenAPIRequestFactory`` before validation.
    :param response: The response, converted with
        ``DjangoOpenAPIResponseFactory`` before validation.
    :param validate_request: If truthy, validate the request.
    :param validate_response: If truthy, validate the response.
    :raises: Whatever ``raise_for_errors()`` raises for a failed validation,
        except for the tolerated request-side cases handled below.
    """
    if response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED:
        return

    version = resolve(path).kwargs.get('version')
    spec = _load_spec(version)
    openapi_request = DjangoOpenAPIRequestFactory.create(request)
    openapi_response = DjangoOpenAPIResponseFactory.create(response)

    if validate_request:
        request_result = RequestValidator(
            spec, custom_formatters=CUSTOM_FORMATTERS,
        ).validate(openapi_request)
        try:
            request_result.raise_for_errors()
        except OpenAPIMediaTypeError:
            # A media type problem is tolerated only when the response was
            # itself a 400; anything else is re-raised.
            if openapi_response.status_code != status.HTTP_400_BAD_REQUEST:
                raise
        except OpenAPIParameterError:
            # TODO(stephenfin): In API v2.0, this should be an error. As things
            # stand, we silently ignore these issues.
            assert openapi_response.status_code == status.HTTP_200_OK

    if validate_response:
        response_result = ResponseValidator(
            spec, custom_formatters=CUSTOM_FORMATTERS,
        ).validate(openapi_request, openapi_response)
        response_result.raise_for_errors()