Base validator class

This commit is contained in:
Artur Maciag 2020-02-21 16:08:13 +00:00
parent 24276bf37a
commit 8539d9b0f4
4 changed files with 95 additions and 120 deletions

View file

@ -22,17 +22,10 @@ from openapi_core.validation.request.datatypes import (
RequestParameters, RequestValidationResult,
)
from openapi_core.validation.util import get_operation_pattern
from openapi_core.validation.validators import BaseValidator
class RequestValidator(object):
def __init__(
self, spec,
custom_formatters=None, custom_media_type_deserializers=None,
):
self.spec = spec
self.custom_formatters = custom_formatters
self.custom_media_type_deserializers = custom_media_type_deserializers
class RequestValidator(BaseValidator):
def validate(self, request):
try:
@ -81,25 +74,6 @@ class RequestValidator(object):
body, body_errors = self._get_body(request, operation)
return RequestValidationResult(body_errors, body, None, None)
def _get_operation_pattern(self, request):
server = self.spec.get_server(request.full_url_pattern)
return get_operation_pattern(
server.default_url, request.full_url_pattern
)
def _find_path(self, request):
operation_pattern = self._get_operation_pattern(request)
path = self.spec[operation_pattern]
path_variables = {}
operation = self.spec.get_operation(operation_pattern, request.method)
servers = path.servers or operation.servers or self.spec.servers
server = servers[0]
server_variables = {}
return path, operation, server, path_variables, server_variables
def _get_security(self, request, operation):
security = operation.security or self.spec.security
if not security:
@ -222,15 +196,6 @@ class RequestValidator(object):
raise MissingRequestBody(request)
return request.body
def _deserialise_media_type(self, media_type, value):
from openapi_core.deserializing.media_types.factories import (
MediaTypeDeserializersFactory,
)
deserializers_factory = MediaTypeDeserializersFactory(
self.custom_media_type_deserializers)
deserializer = deserializers_factory.create(media_type)
return deserializer(value)
def _deserialise_parameter(self, param, value):
from openapi_core.deserializing.parameters.factories import (
ParameterDeserializersFactory,
@ -239,27 +204,7 @@ class RequestValidator(object):
deserializer = deserializers_factory.create(param)
return deserializer(value)
def _cast(self, param_or_media_type, value):
# return param_or_media_type.cast(value)
if not param_or_media_type.schema:
return value
from openapi_core.casting.schemas.factories import SchemaCastersFactory
casters_factory = SchemaCastersFactory()
caster = casters_factory.create(param_or_media_type.schema)
return caster(value)
def _unmarshal(self, param_or_media_type, value):
if not param_or_media_type.schema:
return value
from openapi_core.unmarshalling.schemas.factories import (
SchemaUnmarshallersFactory,
return super(RequestValidator, self)._unmarshal(
param_or_media_type, value, context=UnmarshalContext.REQUEST,
)
unmarshallers_factory = SchemaUnmarshallersFactory(
self.spec._resolver, self.custom_formatters,
context=UnmarshalContext.REQUEST,
)
unmarshaller = unmarshallers_factory.create(
param_or_media_type.schema)
return unmarshaller(value)

View file

@ -3,6 +3,7 @@ from openapi_core.casting.schemas.exceptions import CastError
from openapi_core.deserializing.exceptions import DeserializeError
from openapi_core.schema.operations.exceptions import InvalidOperation
from openapi_core.schema.media_types.exceptions import InvalidContentType
from openapi_core.schema.paths.exceptions import InvalidPath
from openapi_core.schema.responses.exceptions import (
InvalidResponse, MissingResponseContent,
)
@ -13,24 +14,23 @@ from openapi_core.unmarshalling.schemas.exceptions import (
)
from openapi_core.validation.response.datatypes import ResponseValidationResult
from openapi_core.validation.util import get_operation_pattern
from openapi_core.validation.validators import BaseValidator
class ResponseValidator(object):
def __init__(
self, spec,
custom_formatters=None, custom_media_type_deserializers=None,
):
self.spec = spec
self.custom_formatters = custom_formatters
self.custom_media_type_deserializers = custom_media_type_deserializers
class ResponseValidator(BaseValidator):
def validate(self, request, response):
try:
operation_response = self._get_operation_response(
request, response)
_, operation, _, _, _ = self._find_path(request)
# don't process if operation errors
except (InvalidServer, InvalidOperation, InvalidResponse) as exc:
except (InvalidServer, InvalidPath, InvalidOperation) as exc:
return ResponseValidationResult([exc, ], None, None)
try:
operation_response = self._get_operation_response(
operation, response)
# don't process if operation errors
except InvalidResponse as exc:
return ResponseValidationResult([exc, ], None, None)
data, data_errors = self._get_data(response, operation_response)
@ -41,28 +41,21 @@ class ResponseValidator(object):
errors = data_errors + headers_errors
return ResponseValidationResult(errors, data, headers)
def _get_operation_pattern(self, request):
server = self.spec.get_server(request.full_url_pattern)
return get_operation_pattern(
server.default_url, request.full_url_pattern
)
def _get_operation(self, request):
operation_pattern = self._get_operation_pattern(request)
return self.spec.get_operation(operation_pattern, request.method)
def _get_operation_response(self, request, response):
operation = self._get_operation(request)
def _get_operation_response(self, operation, response):
return operation.get_response(str(response.status_code))
def _validate_data(self, request, response):
try:
_, operation, _, _, _ = self._find_path(request)
# don't process if operation errors
except (InvalidServer, InvalidPath, InvalidOperation) as exc:
return ResponseValidationResult([exc, ], None, None)
try:
operation_response = self._get_operation_response(
request, response)
except (InvalidServer, InvalidOperation, InvalidResponse) as exc:
operation, response)
# don't process if operation errors
except InvalidResponse as exc:
return ResponseValidationResult([exc, ], None, None)
data, data_errors = self._get_data(response, operation_response)
@ -113,36 +106,7 @@ class ResponseValidator(object):
return response.data
def _deserialise_media_type(self, media_type, value):
from openapi_core.deserializing.media_types.factories import (
MediaTypeDeserializersFactory,
)
deserializers_factory = MediaTypeDeserializersFactory(
self.custom_media_type_deserializers)
deserializer = deserializers_factory.create(media_type)
return deserializer(value)
def _cast(self, param_or_media_type, value):
# return param_or_media_type.cast(value)
if not param_or_media_type.schema:
return value
from openapi_core.casting.schemas.factories import SchemaCastersFactory
casters_factory = SchemaCastersFactory()
caster = casters_factory.create(param_or_media_type.schema)
return caster(value)
def _unmarshal(self, param_or_media_type, value):
if not param_or_media_type.schema:
return value
from openapi_core.unmarshalling.schemas.factories import (
SchemaUnmarshallersFactory,
return super(ResponseValidator, self)._unmarshal(
param_or_media_type, value, context=UnmarshalContext.RESPONSE,
)
unmarshallers_factory = SchemaUnmarshallersFactory(
self.spec._resolver, self.custom_formatters,
context=UnmarshalContext.RESPONSE,
)
unmarshaller = unmarshallers_factory.create(
param_or_media_type.schema)
return unmarshaller(value)

View file

@ -0,0 +1,66 @@
"""OpenAPI core validation validators module"""
from openapi_core.validation.util import get_operation_pattern
class BaseValidator(object):
def __init__(
self, spec,
custom_formatters=None, custom_media_type_deserializers=None,
):
self.spec = spec
self.custom_formatters = custom_formatters
self.custom_media_type_deserializers = custom_media_type_deserializers
def _find_path(self, request):
operation_pattern = self._get_operation_pattern(request)
path = self.spec[operation_pattern]
path_variables = {}
operation = self.spec.get_operation(operation_pattern, request.method)
servers = path.servers or operation.servers or self.spec.servers
server = servers[0]
server_variables = {}
return path, operation, server, path_variables, server_variables
def _get_operation_pattern(self, request):
server = self.spec.get_server(request.full_url_pattern)
return get_operation_pattern(
server.default_url, request.full_url_pattern
)
def _deserialise_media_type(self, media_type, value):
from openapi_core.deserializing.media_types.factories import (
MediaTypeDeserializersFactory,
)
deserializers_factory = MediaTypeDeserializersFactory(
self.custom_media_type_deserializers)
deserializer = deserializers_factory.create(media_type)
return deserializer(value)
def _cast(self, param_or_media_type, value):
# return param_or_media_type.cast(value)
if not param_or_media_type.schema:
return value
from openapi_core.casting.schemas.factories import SchemaCastersFactory
casters_factory = SchemaCastersFactory()
caster = casters_factory.create(param_or_media_type.schema)
return caster(value)
def _unmarshal(self, param_or_media_type, value, context):
if not param_or_media_type.schema:
return value
from openapi_core.unmarshalling.schemas.factories import (
SchemaUnmarshallersFactory,
)
unmarshallers_factory = SchemaUnmarshallersFactory(
self.spec._resolver, self.custom_formatters,
context=context,
)
unmarshaller = unmarshallers_factory.create(
param_or_media_type.schema)
return unmarshaller(value)

View file

@ -439,7 +439,7 @@ class TestResponseValidator(object):
result = validator.validate(request, response)
assert len(result.errors) == 1
assert type(result.errors[0]) == InvalidOperation
assert type(result.errors[0]) == InvalidPath
assert result.data is None
assert result.headers is None