mirror of
https://github.com/correl/openapi-core.git
synced 2024-11-22 03:00:10 +00:00
Base validator class
This commit is contained in:
parent
24276bf37a
commit
8539d9b0f4
4 changed files with 95 additions and 120 deletions
|
@ -22,17 +22,10 @@ from openapi_core.validation.request.datatypes import (
|
|||
RequestParameters, RequestValidationResult,
|
||||
)
|
||||
from openapi_core.validation.util import get_operation_pattern
|
||||
from openapi_core.validation.validators import BaseValidator
|
||||
|
||||
|
||||
class RequestValidator(object):
|
||||
|
||||
def __init__(
|
||||
self, spec,
|
||||
custom_formatters=None, custom_media_type_deserializers=None,
|
||||
):
|
||||
self.spec = spec
|
||||
self.custom_formatters = custom_formatters
|
||||
self.custom_media_type_deserializers = custom_media_type_deserializers
|
||||
class RequestValidator(BaseValidator):
|
||||
|
||||
def validate(self, request):
|
||||
try:
|
||||
|
@ -81,25 +74,6 @@ class RequestValidator(object):
|
|||
body, body_errors = self._get_body(request, operation)
|
||||
return RequestValidationResult(body_errors, body, None, None)
|
||||
|
||||
def _get_operation_pattern(self, request):
|
||||
server = self.spec.get_server(request.full_url_pattern)
|
||||
|
||||
return get_operation_pattern(
|
||||
server.default_url, request.full_url_pattern
|
||||
)
|
||||
|
||||
def _find_path(self, request):
|
||||
operation_pattern = self._get_operation_pattern(request)
|
||||
|
||||
path = self.spec[operation_pattern]
|
||||
path_variables = {}
|
||||
operation = self.spec.get_operation(operation_pattern, request.method)
|
||||
servers = path.servers or operation.servers or self.spec.servers
|
||||
server = servers[0]
|
||||
server_variables = {}
|
||||
|
||||
return path, operation, server, path_variables, server_variables
|
||||
|
||||
def _get_security(self, request, operation):
|
||||
security = operation.security or self.spec.security
|
||||
if not security:
|
||||
|
@ -222,15 +196,6 @@ class RequestValidator(object):
|
|||
raise MissingRequestBody(request)
|
||||
return request.body
|
||||
|
||||
def _deserialise_media_type(self, media_type, value):
|
||||
from openapi_core.deserializing.media_types.factories import (
|
||||
MediaTypeDeserializersFactory,
|
||||
)
|
||||
deserializers_factory = MediaTypeDeserializersFactory(
|
||||
self.custom_media_type_deserializers)
|
||||
deserializer = deserializers_factory.create(media_type)
|
||||
return deserializer(value)
|
||||
|
||||
def _deserialise_parameter(self, param, value):
|
||||
from openapi_core.deserializing.parameters.factories import (
|
||||
ParameterDeserializersFactory,
|
||||
|
@ -239,27 +204,7 @@ class RequestValidator(object):
|
|||
deserializer = deserializers_factory.create(param)
|
||||
return deserializer(value)
|
||||
|
||||
def _cast(self, param_or_media_type, value):
|
||||
# return param_or_media_type.cast(value)
|
||||
if not param_or_media_type.schema:
|
||||
return value
|
||||
|
||||
from openapi_core.casting.schemas.factories import SchemaCastersFactory
|
||||
casters_factory = SchemaCastersFactory()
|
||||
caster = casters_factory.create(param_or_media_type.schema)
|
||||
return caster(value)
|
||||
|
||||
def _unmarshal(self, param_or_media_type, value):
|
||||
if not param_or_media_type.schema:
|
||||
return value
|
||||
|
||||
from openapi_core.unmarshalling.schemas.factories import (
|
||||
SchemaUnmarshallersFactory,
|
||||
return super(RequestValidator, self)._unmarshal(
|
||||
param_or_media_type, value, context=UnmarshalContext.REQUEST,
|
||||
)
|
||||
unmarshallers_factory = SchemaUnmarshallersFactory(
|
||||
self.spec._resolver, self.custom_formatters,
|
||||
context=UnmarshalContext.REQUEST,
|
||||
)
|
||||
unmarshaller = unmarshallers_factory.create(
|
||||
param_or_media_type.schema)
|
||||
return unmarshaller(value)
|
||||
|
|
|
@ -3,6 +3,7 @@ from openapi_core.casting.schemas.exceptions import CastError
|
|||
from openapi_core.deserializing.exceptions import DeserializeError
|
||||
from openapi_core.schema.operations.exceptions import InvalidOperation
|
||||
from openapi_core.schema.media_types.exceptions import InvalidContentType
|
||||
from openapi_core.schema.paths.exceptions import InvalidPath
|
||||
from openapi_core.schema.responses.exceptions import (
|
||||
InvalidResponse, MissingResponseContent,
|
||||
)
|
||||
|
@ -13,24 +14,23 @@ from openapi_core.unmarshalling.schemas.exceptions import (
|
|||
)
|
||||
from openapi_core.validation.response.datatypes import ResponseValidationResult
|
||||
from openapi_core.validation.util import get_operation_pattern
|
||||
from openapi_core.validation.validators import BaseValidator
|
||||
|
||||
|
||||
class ResponseValidator(object):
|
||||
|
||||
def __init__(
|
||||
self, spec,
|
||||
custom_formatters=None, custom_media_type_deserializers=None,
|
||||
):
|
||||
self.spec = spec
|
||||
self.custom_formatters = custom_formatters
|
||||
self.custom_media_type_deserializers = custom_media_type_deserializers
|
||||
class ResponseValidator(BaseValidator):
|
||||
|
||||
def validate(self, request, response):
|
||||
try:
|
||||
operation_response = self._get_operation_response(
|
||||
request, response)
|
||||
_, operation, _, _, _ = self._find_path(request)
|
||||
# don't process if operation errors
|
||||
except (InvalidServer, InvalidOperation, InvalidResponse) as exc:
|
||||
except (InvalidServer, InvalidPath, InvalidOperation) as exc:
|
||||
return ResponseValidationResult([exc, ], None, None)
|
||||
|
||||
try:
|
||||
operation_response = self._get_operation_response(
|
||||
operation, response)
|
||||
# don't process if operation errors
|
||||
except InvalidResponse as exc:
|
||||
return ResponseValidationResult([exc, ], None, None)
|
||||
|
||||
data, data_errors = self._get_data(response, operation_response)
|
||||
|
@ -41,28 +41,21 @@ class ResponseValidator(object):
|
|||
errors = data_errors + headers_errors
|
||||
return ResponseValidationResult(errors, data, headers)
|
||||
|
||||
def _get_operation_pattern(self, request):
|
||||
server = self.spec.get_server(request.full_url_pattern)
|
||||
|
||||
return get_operation_pattern(
|
||||
server.default_url, request.full_url_pattern
|
||||
)
|
||||
|
||||
def _get_operation(self, request):
|
||||
operation_pattern = self._get_operation_pattern(request)
|
||||
|
||||
return self.spec.get_operation(operation_pattern, request.method)
|
||||
|
||||
def _get_operation_response(self, request, response):
|
||||
operation = self._get_operation(request)
|
||||
|
||||
def _get_operation_response(self, operation, response):
|
||||
return operation.get_response(str(response.status_code))
|
||||
|
||||
def _validate_data(self, request, response):
|
||||
try:
|
||||
_, operation, _, _, _ = self._find_path(request)
|
||||
# don't process if operation errors
|
||||
except (InvalidServer, InvalidPath, InvalidOperation) as exc:
|
||||
return ResponseValidationResult([exc, ], None, None)
|
||||
|
||||
try:
|
||||
operation_response = self._get_operation_response(
|
||||
request, response)
|
||||
except (InvalidServer, InvalidOperation, InvalidResponse) as exc:
|
||||
operation, response)
|
||||
# don't process if operation errors
|
||||
except InvalidResponse as exc:
|
||||
return ResponseValidationResult([exc, ], None, None)
|
||||
|
||||
data, data_errors = self._get_data(response, operation_response)
|
||||
|
@ -113,36 +106,7 @@ class ResponseValidator(object):
|
|||
|
||||
return response.data
|
||||
|
||||
def _deserialise_media_type(self, media_type, value):
|
||||
from openapi_core.deserializing.media_types.factories import (
|
||||
MediaTypeDeserializersFactory,
|
||||
)
|
||||
deserializers_factory = MediaTypeDeserializersFactory(
|
||||
self.custom_media_type_deserializers)
|
||||
deserializer = deserializers_factory.create(media_type)
|
||||
return deserializer(value)
|
||||
|
||||
def _cast(self, param_or_media_type, value):
|
||||
# return param_or_media_type.cast(value)
|
||||
if not param_or_media_type.schema:
|
||||
return value
|
||||
|
||||
from openapi_core.casting.schemas.factories import SchemaCastersFactory
|
||||
casters_factory = SchemaCastersFactory()
|
||||
caster = casters_factory.create(param_or_media_type.schema)
|
||||
return caster(value)
|
||||
|
||||
def _unmarshal(self, param_or_media_type, value):
|
||||
if not param_or_media_type.schema:
|
||||
return value
|
||||
|
||||
from openapi_core.unmarshalling.schemas.factories import (
|
||||
SchemaUnmarshallersFactory,
|
||||
return super(ResponseValidator, self)._unmarshal(
|
||||
param_or_media_type, value, context=UnmarshalContext.RESPONSE,
|
||||
)
|
||||
unmarshallers_factory = SchemaUnmarshallersFactory(
|
||||
self.spec._resolver, self.custom_formatters,
|
||||
context=UnmarshalContext.RESPONSE,
|
||||
)
|
||||
unmarshaller = unmarshallers_factory.create(
|
||||
param_or_media_type.schema)
|
||||
return unmarshaller(value)
|
||||
|
|
66
openapi_core/validation/validators.py
Normal file
66
openapi_core/validation/validators.py
Normal file
|
@ -0,0 +1,66 @@
|
|||
"""OpenAPI core validation validators module"""
|
||||
from openapi_core.validation.util import get_operation_pattern
|
||||
|
||||
|
||||
class BaseValidator(object):
|
||||
|
||||
def __init__(
|
||||
self, spec,
|
||||
custom_formatters=None, custom_media_type_deserializers=None,
|
||||
):
|
||||
self.spec = spec
|
||||
self.custom_formatters = custom_formatters
|
||||
self.custom_media_type_deserializers = custom_media_type_deserializers
|
||||
|
||||
def _find_path(self, request):
|
||||
operation_pattern = self._get_operation_pattern(request)
|
||||
|
||||
path = self.spec[operation_pattern]
|
||||
path_variables = {}
|
||||
operation = self.spec.get_operation(operation_pattern, request.method)
|
||||
servers = path.servers or operation.servers or self.spec.servers
|
||||
server = servers[0]
|
||||
server_variables = {}
|
||||
|
||||
return path, operation, server, path_variables, server_variables
|
||||
|
||||
def _get_operation_pattern(self, request):
|
||||
server = self.spec.get_server(request.full_url_pattern)
|
||||
|
||||
return get_operation_pattern(
|
||||
server.default_url, request.full_url_pattern
|
||||
)
|
||||
|
||||
def _deserialise_media_type(self, media_type, value):
|
||||
from openapi_core.deserializing.media_types.factories import (
|
||||
MediaTypeDeserializersFactory,
|
||||
)
|
||||
deserializers_factory = MediaTypeDeserializersFactory(
|
||||
self.custom_media_type_deserializers)
|
||||
deserializer = deserializers_factory.create(media_type)
|
||||
return deserializer(value)
|
||||
|
||||
def _cast(self, param_or_media_type, value):
|
||||
# return param_or_media_type.cast(value)
|
||||
if not param_or_media_type.schema:
|
||||
return value
|
||||
|
||||
from openapi_core.casting.schemas.factories import SchemaCastersFactory
|
||||
casters_factory = SchemaCastersFactory()
|
||||
caster = casters_factory.create(param_or_media_type.schema)
|
||||
return caster(value)
|
||||
|
||||
def _unmarshal(self, param_or_media_type, value, context):
|
||||
if not param_or_media_type.schema:
|
||||
return value
|
||||
|
||||
from openapi_core.unmarshalling.schemas.factories import (
|
||||
SchemaUnmarshallersFactory,
|
||||
)
|
||||
unmarshallers_factory = SchemaUnmarshallersFactory(
|
||||
self.spec._resolver, self.custom_formatters,
|
||||
context=context,
|
||||
)
|
||||
unmarshaller = unmarshallers_factory.create(
|
||||
param_or_media_type.schema)
|
||||
return unmarshaller(value)
|
|
@ -439,7 +439,7 @@ class TestResponseValidator(object):
|
|||
result = validator.validate(request, response)
|
||||
|
||||
assert len(result.errors) == 1
|
||||
assert type(result.errors[0]) == InvalidOperation
|
||||
assert type(result.errors[0]) == InvalidPath
|
||||
assert result.data is None
|
||||
assert result.headers is None
|
||||
|
||||
|
|
Loading…
Reference in a new issue