Factories and exceptions cleanup

This commit is contained in:
Artur Maciag 2017-11-03 11:18:48 +00:00
parent e85d2d9f8a
commit 37f09d7571
12 changed files with 238 additions and 270 deletions

View file

@ -1,6 +1,5 @@
"""OpenAPI core module""" """OpenAPI core module"""
from openapi_core.shortcuts import create_spec from openapi_core.shortcuts import create_spec
from openapi_core.wrappers import RequestParametersFactory, RequestBodyFactory
__author__ = 'Artur Maciąg' __author__ = 'Artur Maciąg'
__email__ = 'maciag.artur@gmail.com' __email__ = 'maciag.artur@gmail.com'
@ -8,7 +7,4 @@ __version__ = '0.2.2'
__url__ = 'https://github.com/p1c2u/openapi-core' __url__ = 'https://github.com/p1c2u/openapi-core'
__license__ = 'BSD 3-Clause License' __license__ = 'BSD 3-Clause License'
__all__ = ['create_spec', 'request_parameters_factory', 'request_body_factory'] __all__ = ['create_spec', ]
request_parameters_factory = RequestParametersFactory()
request_body_factory = RequestBodyFactory()

View file

@ -9,27 +9,11 @@ class OpenAPIMappingError(OpenAPIError):
pass pass
class MissingParameterError(OpenAPIMappingError): class OpenAPIServerError(OpenAPIMappingError):
pass pass
class MissingBodyError(OpenAPIMappingError): class OpenAPIOperationError(OpenAPIMappingError):
pass
class MissingPropertyError(OpenAPIMappingError):
pass
class InvalidContentTypeError(OpenAPIMappingError):
pass
class InvalidOperationError(OpenAPIMappingError):
pass
class InvalidServerError(OpenAPIMappingError):
pass pass
@ -37,13 +21,53 @@ class InvalidValueType(OpenAPIMappingError):
pass pass
class OpenAPIParameterError(OpenAPIMappingError):
pass
class OpenAPIBodyError(OpenAPIMappingError):
pass
class InvalidServer(OpenAPIServerError):
pass
class InvalidOperation(OpenAPIOperationError):
pass
class EmptyValue(OpenAPIParameterError):
pass
class MissingParameter(OpenAPIParameterError):
pass
class InvalidParameterValue(OpenAPIParameterError):
pass
class MissingBody(OpenAPIBodyError):
pass
class InvalidMediaTypeValue(OpenAPIBodyError):
pass
class UndefinedSchemaProperty(OpenAPIBodyError):
pass
class MissingProperty(OpenAPIBodyError):
pass
class InvalidContentType(OpenAPIBodyError):
pass
class InvalidValue(OpenAPIMappingError): class InvalidValue(OpenAPIMappingError):
pass pass
class EmptyValue(OpenAPIMappingError):
pass
class UndefinedSchemaProperty(OpenAPIMappingError):
pass

View file

@ -1,6 +1,8 @@
"""OpenAPI core mediaTypes module""" """OpenAPI core mediaTypes module"""
from six import iteritems from six import iteritems
from openapi_core.exceptions import InvalidValueType, InvalidMediaTypeValue
class MediaType(object): class MediaType(object):
"""Represents an OpenAPI MediaType.""" """Represents an OpenAPI MediaType."""
@ -13,7 +15,10 @@ class MediaType(object):
if not self.schema: if not self.schema:
return value return value
return self.schema.unmarshal(value) try:
return self.schema.unmarshal(value)
except InvalidValueType as exc:
raise InvalidMediaTypeValue(str(exc))
class MediaTypeGenerator(object): class MediaTypeGenerator(object):

View file

@ -2,7 +2,9 @@
import logging import logging
import warnings import warnings
from openapi_core.exceptions import EmptyValue from openapi_core.exceptions import (
EmptyValue, InvalidValueType, InvalidParameterValue,
)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -35,12 +37,15 @@ class Parameter(object):
if (self.location == "query" and value == "" and if (self.location == "query" and value == "" and
not self.allow_empty_value): not self.allow_empty_value):
raise EmptyValue( raise EmptyValue(
"Value of {0} parameter cannot be empty.".format(self.name)) "Value of {0} parameter cannot be empty".format(self.name))
if not self.schema: if not self.schema:
return value return value
return self.schema.unmarshal(value) try:
return self.schema.unmarshal(value)
except InvalidValueType as exc:
raise InvalidParameterValue(str(exc))
class ParametersGenerator(object): class ParametersGenerator(object):

View file

@ -1,7 +1,7 @@
"""OpenAPI core requestBodies module""" """OpenAPI core requestBodies module"""
from functools import lru_cache from functools import lru_cache
from openapi_core.exceptions import InvalidContentTypeError from openapi_core.exceptions import InvalidContentType
from openapi_core.media_types import MediaTypeGenerator from openapi_core.media_types import MediaTypeGenerator
@ -16,7 +16,7 @@ class RequestBody(object):
try: try:
return self.content[mimetype] return self.content[mimetype]
except KeyError: except KeyError:
raise InvalidContentTypeError( raise InvalidContentType(
"Invalid mime type `{0}`".format(mimetype)) "Invalid mime type `{0}`".format(mimetype))

View file

@ -10,8 +10,7 @@ from json import loads
from six import iteritems from six import iteritems
from openapi_core.exceptions import ( from openapi_core.exceptions import (
InvalidValueType, UndefinedSchemaProperty, MissingPropertyError, InvalidValueType, UndefinedSchemaProperty, MissingProperty, InvalidValue,
InvalidValue,
) )
from openapi_core.models import ModelFactory from openapi_core.models import ModelFactory
@ -59,7 +58,8 @@ class Schema(object):
if value is None: if value is None:
if not self.nullable: if not self.nullable:
raise InvalidValueType( raise InvalidValueType(
"Failed to cast value of %s to %s", value, self.type, "Failed to cast value of {0} to {1}".format(
value, self.type)
) )
return self.default return self.default
@ -73,7 +73,7 @@ class Schema(object):
return cast_callable(value) return cast_callable(value)
except ValueError: except ValueError:
raise InvalidValueType( raise InvalidValueType(
"Failed to cast value of %s to %s", value, self.type, "Failed to cast value of {0} to {1}".format(value, self.type)
) )
def unmarshal(self, value): def unmarshal(self, value):
@ -88,7 +88,8 @@ class Schema(object):
if self.enum and casted not in self.enum: if self.enum and casted not in self.enum:
raise InvalidValue( raise InvalidValue(
"Value of %s not in enum choices: %s", value, str(self.enum), "Value of {0} not in enum choices: {1}".format(
value, self.enum)
) )
return casted return casted
@ -115,7 +116,7 @@ class Schema(object):
prop_value = value[prop_name] prop_value = value[prop_name]
except KeyError: except KeyError:
if prop_name in self.required: if prop_name in self.required:
raise MissingPropertyError( raise MissingProperty(
"Missing schema property {0}".format(prop_name)) "Missing schema property {0}".format(prop_name))
if not prop.nullable and not prop.default: if not prop.nullable and not prop.default:
continue continue

View file

@ -6,7 +6,7 @@ from functools import partialmethod, lru_cache
from openapi_spec_validator import openapi_v3_spec_validator from openapi_spec_validator import openapi_v3_spec_validator
from openapi_core.components import ComponentsFactory from openapi_core.components import ComponentsFactory
from openapi_core.exceptions import InvalidOperationError, InvalidServerError from openapi_core.exceptions import InvalidOperation, InvalidServer
from openapi_core.infos import InfoFactory from openapi_core.infos import InfoFactory
from openapi_core.paths import PathsGenerator from openapi_core.paths import PathsGenerator
from openapi_core.schemas import SchemaRegistry from openapi_core.schemas import SchemaRegistry
@ -37,7 +37,7 @@ class Spec(object):
if spec_server.default_url in full_url_pattern: if spec_server.default_url in full_url_pattern:
return spec_server return spec_server
raise InvalidServerError( raise InvalidServer(
"Invalid request server {0}".format(full_url_pattern)) "Invalid request server {0}".format(full_url_pattern))
def get_server_url(self, index=0): def get_server_url(self, index=0):
@ -47,7 +47,7 @@ class Spec(object):
try: try:
return self.paths[path_pattern].operations[http_method] return self.paths[path_pattern].operations[http_method]
except KeyError: except KeyError:
raise InvalidOperationError( raise InvalidOperation(
"Unknown operation path {0} with method {1}".format( "Unknown operation path {0} with method {1}".format(
path_pattern, http_method)) path_pattern, http_method))

View file

@ -2,7 +2,7 @@
from six import iteritems from six import iteritems
from openapi_core.exceptions import ( from openapi_core.exceptions import (
OpenAPIMappingError, MissingParameterError, MissingBodyError, OpenAPIMappingError, MissingParameter, MissingBody,
) )
@ -45,13 +45,6 @@ class RequestValidationResult(BaseValidationResult):
class RequestValidator(object): class RequestValidator(object):
SPEC_LOCATION_TO_REQUEST_LOCATION = {
'path': 'view_args',
'query': 'args',
'headers': 'headers',
'cookies': 'cookies',
}
def __init__(self, spec): def __init__(self, spec):
self.spec = spec self.spec = spec
@ -69,10 +62,10 @@ class RequestValidator(object):
operation_pattern = request.full_url_pattern.replace( operation_pattern = request.full_url_pattern.replace(
server.default_url, '') server.default_url, '')
method = request.method.lower()
try: try:
operation = self.spec.get_operation(operation_pattern, method) operation = self.spec.get_operation(
operation_pattern, request.method)
# don't process if operation errors # don't process if operation errors
except OpenAPIMappingError as exc: except OpenAPIMappingError as exc:
errors.append(exc) errors.append(exc)
@ -81,7 +74,7 @@ class RequestValidator(object):
for param_name, param in iteritems(operation.parameters): for param_name, param in iteritems(operation.parameters):
try: try:
raw_value = self._get_raw_value(request, param) raw_value = self._get_raw_value(request, param)
except MissingParameterError as exc: except MissingParameter as exc:
if param.required: if param.required:
errors.append(exc) errors.append(exc)
@ -89,9 +82,12 @@ class RequestValidator(object):
continue continue
raw_value = param.schema.default raw_value = param.schema.default
value = param.unmarshal(raw_value) try:
value = param.unmarshal(raw_value)
parameters[param.location][param_name] = value except OpenAPIMappingError as exc:
errors.append(exc)
else:
parameters[param.location][param_name] = value
if operation.request_body is not None: if operation.request_body is not None:
try: try:
@ -101,29 +97,26 @@ class RequestValidator(object):
else: else:
try: try:
raw_body = self._get_raw_body(request) raw_body = self._get_raw_body(request)
except MissingBodyError as exc: except MissingBody as exc:
if operation.request_body.required: if operation.request_body.required:
errors.append(exc) errors.append(exc)
else: else:
body = media_type.unmarshal(raw_body) try:
body = media_type.unmarshal(raw_body)
except OpenAPIMappingError as exc:
errors.append(exc)
return RequestValidationResult(errors, body, parameters) return RequestValidationResult(errors, body, parameters)
def _get_request_location(self, spec_location):
return self.SPEC_LOCATION_TO_REQUEST_LOCATION[spec_location]
def _get_raw_value(self, request, param): def _get_raw_value(self, request, param):
request_location = self._get_request_location(param.location)
request_attr = getattr(request, request_location)
try: try:
return request_attr[param.name] return request.parameters[param.location][param.name]
except KeyError: except KeyError:
raise MissingParameterError( raise MissingParameter(
"Missing required `{0}` parameter".format(param.name)) "Missing required `{0}` parameter".format(param.name))
def _get_raw_body(self, request): def _get_raw_body(self, request):
if not request.data: if not request.body:
raise MissingBodyError("Missing required request body") raise MissingBody("Missing required request body")
return request.data return request.body

View file

@ -1,110 +1,10 @@
"""OpenAPI core wrappers module""" """OpenAPI core wrappers module"""
from six import iteritems import warnings
from six.moves.urllib.parse import urljoin from six.moves.urllib.parse import urljoin
from openapi_core.exceptions import ( from openapi_core.exceptions import OpenAPIParameterError, OpenAPIBodyError
OpenAPIMappingError, MissingParameterError, InvalidContentTypeError, from openapi_core.validators import RequestValidator
InvalidServerError,
)
SPEC_LOCATION_TO_REQUEST_LOCATION_MAPPING = {
'path': 'view_args',
'query': 'args',
'headers': 'headers',
'cookies': 'cookies',
}
class RequestParameters(dict):
valid_locations = ['path', 'query', 'headers', 'cookies']
def __getitem__(self, location):
self.validate_location(location)
return self.setdefault(location, {})
def __setitem__(self, location, value):
raise NotImplementedError
@classmethod
def validate_location(cls, location):
if location not in cls.valid_locations:
raise OpenAPIMappingError(
"Unknown parameter location: {0}".format(str(location)))
class BaseRequestFactory(object):
def get_operation(self, request, spec):
server = self._get_server(request, spec)
operation_pattern = request.full_url_pattern.replace(
server.default_url, '')
method = request.method.lower()
return spec.get_operation(operation_pattern, method)
def _get_server(self, request, spec):
for server in spec.servers:
if server.default_url in request.full_url_pattern:
return server
raise InvalidServerError(
"Invalid request server {0}".format(request.full_url_pattern))
class RequestParametersFactory(BaseRequestFactory):
def __init__(self, attr_mapping=SPEC_LOCATION_TO_REQUEST_LOCATION_MAPPING):
self.attr_mapping = attr_mapping
def create(self, request, spec):
operation = self.get_operation(request, spec)
params = RequestParameters()
for param_name, param in iteritems(operation.parameters):
try:
raw_value = self._get_raw_value(request, param)
except MissingParameterError:
if param.required:
raise
if not param.schema or param.schema.default is None:
continue
raw_value = param.schema.default
value = param.unmarshal(raw_value)
params[param.location][param_name] = value
return params
def _get_raw_value(self, request, param):
request_location = self.attr_mapping[param.location]
request_attr = getattr(request, request_location)
try:
return request_attr[param.name]
except KeyError:
raise MissingParameterError(
"Missing required `{0}` parameter".format(param.name))
class RequestBodyFactory(BaseRequestFactory):
def create(self, request, spec):
operation = self.get_operation(request, spec)
if operation.request_body is None:
return None
try:
media_type = operation.request_body[request.mimetype]
except KeyError:
raise InvalidContentTypeError(
"Invalid media type `{0}`".format(request.mimetype))
return media_type.unmarshal(request.data)
class BaseOpenAPIRequest(object): class BaseOpenAPIRequest(object):
@ -114,12 +14,8 @@ class BaseOpenAPIRequest(object):
path_pattern = NotImplemented path_pattern = NotImplemented
method = NotImplemented method = NotImplemented
args = NotImplemented parameters = NotImplemented
view_args = NotImplemented body = NotImplemented
headers = NotImplemented
cookies = NotImplemented
data = NotImplemented
mimetype = NotImplemented mimetype = NotImplemented
@ -127,8 +23,96 @@ class BaseOpenAPIRequest(object):
def full_url_pattern(self): def full_url_pattern(self):
return urljoin(self.host_url, self.path_pattern) return urljoin(self.host_url, self.path_pattern)
def get_parameters(self, spec):
return RequestParametersFactory().create(self, spec)
def get_body(self, spec): def get_body(self, spec):
return RequestBodyFactory().create(self, spec) warnings.warn(
"`get_body` method is deprecated. "
"Use RequestValidator instead.",
DeprecationWarning,
)
# backward compatibility
validator = RequestValidator(spec)
result = validator.validate(self)
try:
result.validate()
except OpenAPIParameterError:
return result.body
else:
return result.body
def get_parameters(self, spec):
warnings.warn(
"`get_parameters` method is deprecated. "
"Use RequestValidator instead.",
DeprecationWarning,
)
# backward compatibility
validator = RequestValidator(spec)
result = validator.validate(self)
try:
result.validate()
except OpenAPIBodyError:
return result.parameters
else:
return result.parameters
class MockRequest(BaseOpenAPIRequest):
def __init__(
self, host_url, method, path, path_pattern=None, args=None,
view_args=None, headers=None, cookies=None, data=None,
mimetype='application/json'):
self.host_url = host_url
self.path = path
self.path_pattern = path_pattern or path
self.method = method.lower()
self.parameters = {
'path': view_args or {},
'query': args or {},
'headers': headers or {},
'cookies': cookies or {},
}
self.body = data or ''
self.mimetype = mimetype
class WerkzeugOpenAPIRequest(BaseOpenAPIRequest):
def __init__(self, request):
self.request = request
@property
def host_url(self):
return self.request.host_url
@property
def path(self):
return self.request.path
@property
def method(self):
return self.request.method.lower()
@property
def path_pattern(self):
return self.request.url_rule.rule
@property
def parameters(self):
return {
'path': self.request['view_args'],
'query': self.request['args'],
'headers': self.request['headers'],
'cookies': self.request['cookies'],
}
@property
def body(self):
return self.request.data
@property
def mimetype(self):
return self.request.mimetype

View file

@ -3,9 +3,9 @@ import pytest
from six import iteritems from six import iteritems
from openapi_core.exceptions import ( from openapi_core.exceptions import (
MissingParameterError, InvalidContentTypeError, InvalidServerError, MissingParameter, InvalidContentType, InvalidServer,
InvalidValueType, UndefinedSchemaProperty, MissingPropertyError, UndefinedSchemaProperty, MissingProperty,
EmptyValue, EmptyValue, InvalidMediaTypeValue, InvalidParameterValue,
) )
from openapi_core.media_types import MediaType from openapi_core.media_types import MediaType
from openapi_core.operations import Operation from openapi_core.operations import Operation
@ -14,27 +14,7 @@ from openapi_core.request_bodies import RequestBody
from openapi_core.schemas import Schema from openapi_core.schemas import Schema
from openapi_core.servers import Server, ServerVariable from openapi_core.servers import Server, ServerVariable
from openapi_core.shortcuts import create_spec from openapi_core.shortcuts import create_spec
from openapi_core.wrappers import BaseOpenAPIRequest from openapi_core.wrappers import MockRequest
class RequestMock(BaseOpenAPIRequest):
def __init__(
self, host_url, method, path, path_pattern=None, args=None,
view_args=None, headers=None, cookies=None, data=None,
mimetype='application/json'):
self.host_url = host_url
self.path = path
self.path_pattern = path_pattern or path
self.method = method
self.args = args or {}
self.view_args = view_args or {}
self.headers = headers or {}
self.cookies = cookies or {}
self.data = data or ''
self.mimetype = mimetype
class TestPetstore(object): class TestPetstore(object):
@ -125,7 +105,7 @@ class TestPetstore(object):
'ids': ['12', '13'], 'ids': ['12', '13'],
} }
request = RequestMock( request = MockRequest(
host_url, 'GET', '/pets', host_url, 'GET', '/pets',
path_pattern=path_pattern, args=query_params, path_pattern=path_pattern, args=query_params,
) )
@ -150,12 +130,12 @@ class TestPetstore(object):
'limit': 'twenty', 'limit': 'twenty',
} }
request = RequestMock( request = MockRequest(
host_url, 'GET', '/pets', host_url, 'GET', '/pets',
path_pattern=path_pattern, args=query_params, path_pattern=path_pattern, args=query_params,
) )
with pytest.raises(InvalidValueType): with pytest.raises(InvalidParameterValue):
request.get_parameters(spec) request.get_parameters(spec)
body = request.get_body(spec) body = request.get_body(spec)
@ -165,12 +145,12 @@ class TestPetstore(object):
def test_get_pets_raises_missing_required_param(self, spec): def test_get_pets_raises_missing_required_param(self, spec):
host_url = 'http://petstore.swagger.io/v1' host_url = 'http://petstore.swagger.io/v1'
path_pattern = '/v1/pets' path_pattern = '/v1/pets'
request = RequestMock( request = MockRequest(
host_url, 'GET', '/pets', host_url, 'GET', '/pets',
path_pattern=path_pattern, path_pattern=path_pattern,
) )
with pytest.raises(MissingParameterError): with pytest.raises(MissingParameter):
request.get_parameters(spec) request.get_parameters(spec)
body = request.get_body(spec) body = request.get_body(spec)
@ -184,7 +164,7 @@ class TestPetstore(object):
'limit': '', 'limit': '',
} }
request = RequestMock( request = MockRequest(
host_url, 'GET', '/pets', host_url, 'GET', '/pets',
path_pattern=path_pattern, args=query_params, path_pattern=path_pattern, args=query_params,
) )
@ -202,7 +182,7 @@ class TestPetstore(object):
'limit': None, 'limit': None,
} }
request = RequestMock( request = MockRequest(
host_url, 'GET', '/pets', host_url, 'GET', '/pets',
path_pattern=path_pattern, args=query_params, path_pattern=path_pattern, args=query_params,
) )
@ -239,7 +219,7 @@ class TestPetstore(object):
} }
data = json.dumps(data_json) data = json.dumps(data_json)
request = RequestMock( request = MockRequest(
host_url, 'POST', '/pets', host_url, 'POST', '/pets',
path_pattern=path_pattern, data=data, path_pattern=path_pattern, data=data,
) )
@ -267,7 +247,7 @@ class TestPetstore(object):
data_json = {} data_json = {}
data = json.dumps(data_json) data = json.dumps(data_json)
request = RequestMock( request = MockRequest(
host_url, 'POST', '/pets', host_url, 'POST', '/pets',
path_pattern=path_pattern, data=data, path_pattern=path_pattern, data=data,
) )
@ -276,7 +256,7 @@ class TestPetstore(object):
assert parameters == {} assert parameters == {}
with pytest.raises(MissingPropertyError): with pytest.raises(MissingProperty):
request.get_body(spec) request.get_body(spec)
def test_post_pets_extra_body_properties(self, spec, spec_dict): def test_post_pets_extra_body_properties(self, spec, spec_dict):
@ -290,7 +270,7 @@ class TestPetstore(object):
} }
data = json.dumps(data_json) data = json.dumps(data_json)
request = RequestMock( request = MockRequest(
host_url, 'POST', '/pets', host_url, 'POST', '/pets',
path_pattern=path_pattern, data=data, path_pattern=path_pattern, data=data,
) )
@ -311,7 +291,7 @@ class TestPetstore(object):
} }
data = json.dumps(data_json) data = json.dumps(data_json)
request = RequestMock( request = MockRequest(
host_url, 'POST', '/pets', host_url, 'POST', '/pets',
path_pattern=path_pattern, data=data, path_pattern=path_pattern, data=data,
) )
@ -342,7 +322,7 @@ class TestPetstore(object):
} }
data = json.dumps(data_json) data = json.dumps(data_json)
request = RequestMock( request = MockRequest(
host_url, 'POST', '/pets', host_url, 'POST', '/pets',
path_pattern=path_pattern, data=data, path_pattern=path_pattern, data=data,
) )
@ -351,7 +331,7 @@ class TestPetstore(object):
assert parameters == {} assert parameters == {}
with pytest.raises(InvalidValueType): with pytest.raises(InvalidMediaTypeValue):
request.get_body(spec) request.get_body(spec)
def test_post_pets_raises_invalid_mimetype(self, spec): def test_post_pets_raises_invalid_mimetype(self, spec):
@ -363,7 +343,7 @@ class TestPetstore(object):
} }
data = json.dumps(data_json) data = json.dumps(data_json)
request = RequestMock( request = MockRequest(
host_url, 'POST', '/pets', host_url, 'POST', '/pets',
path_pattern=path_pattern, data=data, mimetype='text/html', path_pattern=path_pattern, data=data, mimetype='text/html',
) )
@ -372,7 +352,7 @@ class TestPetstore(object):
assert parameters == {} assert parameters == {}
with pytest.raises(InvalidContentTypeError): with pytest.raises(InvalidContentType):
request.get_body(spec) request.get_body(spec)
def test_post_pets_raises_invalid_server_error(self, spec): def test_post_pets_raises_invalid_server_error(self, spec):
@ -384,15 +364,15 @@ class TestPetstore(object):
} }
data = json.dumps(data_json) data = json.dumps(data_json)
request = RequestMock( request = MockRequest(
host_url, 'POST', '/pets', host_url, 'POST', '/pets',
path_pattern=path_pattern, data=data, mimetype='text/html', path_pattern=path_pattern, data=data, mimetype='text/html',
) )
with pytest.raises(InvalidServerError): with pytest.raises(InvalidServer):
request.get_parameters(spec) request.get_parameters(spec)
with pytest.raises(InvalidServerError): with pytest.raises(InvalidServer):
request.get_body(spec) request.get_body(spec)
def test_get_pet(self, spec): def test_get_pet(self, spec):
@ -401,7 +381,7 @@ class TestPetstore(object):
view_args = { view_args = {
'petId': '1', 'petId': '1',
} }
request = RequestMock( request = MockRequest(
host_url, 'GET', '/pets/1', host_url, 'GET', '/pets/1',
path_pattern=path_pattern, view_args=view_args, path_pattern=path_pattern, view_args=view_args,
) )

View file

@ -2,32 +2,12 @@ import json
import pytest import pytest
from openapi_core.exceptions import ( from openapi_core.exceptions import (
InvalidServerError, InvalidOperationError, MissingParameterError, InvalidServer, InvalidOperation, MissingParameter,
MissingBodyError, InvalidContentTypeError, MissingBody, InvalidContentType,
) )
from openapi_core.shortcuts import create_spec from openapi_core.shortcuts import create_spec
from openapi_core.validators import RequestValidator from openapi_core.validators import RequestValidator
from openapi_core.wrappers import BaseOpenAPIRequest from openapi_core.wrappers import MockRequest
class RequestMock(BaseOpenAPIRequest):
def __init__(
self, host_url, method, path, path_pattern=None, args=None,
view_args=None, headers=None, cookies=None, data=None,
mimetype='application/json'):
self.host_url = host_url
self.path = path
self.path_pattern = path_pattern or path
self.method = method
self.args = args or {}
self.view_args = view_args or {}
self.headers = headers or {}
self.cookies = cookies or {}
self.data = data or ''
self.mimetype = mimetype
class TestRequestValidator(object): class TestRequestValidator(object):
@ -47,31 +27,31 @@ class TestRequestValidator(object):
return RequestValidator(spec) return RequestValidator(spec)
def test_request_server_error(self, validator): def test_request_server_error(self, validator):
request = RequestMock('http://petstore.invalid.net/v1', 'get', '/') request = MockRequest('http://petstore.invalid.net/v1', 'get', '/')
result = validator.validate(request) result = validator.validate(request)
assert len(result.errors) == 1 assert len(result.errors) == 1
assert type(result.errors[0]) == InvalidServerError assert type(result.errors[0]) == InvalidServer
assert result.body is None assert result.body is None
assert result.parameters == {} assert result.parameters == {}
def test_invalid_operation(self, validator): def test_invalid_operation(self, validator):
request = RequestMock(self.host_url, 'get', '/v1') request = MockRequest(self.host_url, 'get', '/v1')
result = validator.validate(request) result = validator.validate(request)
assert len(result.errors) == 1 assert len(result.errors) == 1
assert type(result.errors[0]) == InvalidOperationError assert type(result.errors[0]) == InvalidOperation
assert result.body is None assert result.body is None
assert result.parameters == {} assert result.parameters == {}
def test_missing_parameter(self, validator): def test_missing_parameter(self, validator):
request = RequestMock(self.host_url, 'get', '/v1/pets') request = MockRequest(self.host_url, 'get', '/v1/pets')
result = validator.validate(request) result = validator.validate(request)
assert type(result.errors[0]) == MissingParameterError assert type(result.errors[0]) == MissingParameter
assert result.body is None assert result.body is None
assert result.parameters == { assert result.parameters == {
'query': { 'query': {
@ -81,7 +61,7 @@ class TestRequestValidator(object):
} }
def test_get_pets(self, validator): def test_get_pets(self, validator):
request = RequestMock( request = MockRequest(
self.host_url, 'get', '/v1/pets', self.host_url, 'get', '/v1/pets',
path_pattern='/v1/pets', args={'limit': '10'}, path_pattern='/v1/pets', args={'limit': '10'},
) )
@ -99,7 +79,7 @@ class TestRequestValidator(object):
} }
def test_missing_body(self, validator): def test_missing_body(self, validator):
request = RequestMock( request = MockRequest(
self.host_url, 'post', '/v1/pets', self.host_url, 'post', '/v1/pets',
path_pattern='/v1/pets', path_pattern='/v1/pets',
) )
@ -107,12 +87,12 @@ class TestRequestValidator(object):
result = validator.validate(request) result = validator.validate(request)
assert len(result.errors) == 1 assert len(result.errors) == 1
assert type(result.errors[0]) == MissingBodyError assert type(result.errors[0]) == MissingBody
assert result.body is None assert result.body is None
assert result.parameters == {} assert result.parameters == {}
def test_invalid_content_type(self, validator): def test_invalid_content_type(self, validator):
request = RequestMock( request = MockRequest(
self.host_url, 'post', '/v1/pets', self.host_url, 'post', '/v1/pets',
path_pattern='/v1/pets', mimetype='text/csv', path_pattern='/v1/pets', mimetype='text/csv',
) )
@ -120,7 +100,7 @@ class TestRequestValidator(object):
result = validator.validate(request) result = validator.validate(request)
assert len(result.errors) == 1 assert len(result.errors) == 1
assert type(result.errors[0]) == InvalidContentTypeError assert type(result.errors[0]) == InvalidContentType
assert result.body is None assert result.body is None
assert result.parameters == {} assert result.parameters == {}
@ -139,7 +119,7 @@ class TestRequestValidator(object):
} }
} }
data = json.dumps(data_json) data = json.dumps(data_json)
request = RequestMock( request = MockRequest(
self.host_url, 'post', '/v1/pets', self.host_url, 'post', '/v1/pets',
path_pattern='/v1/pets', data=data, path_pattern='/v1/pets', data=data,
) )
@ -161,7 +141,7 @@ class TestRequestValidator(object):
assert result.body.address.city == pet_city assert result.body.address.city == pet_city
def test_get_pet(self, validator): def test_get_pet(self, validator):
request = RequestMock( request = MockRequest(
self.host_url, 'get', '/v1/pets/1', self.host_url, 'get', '/v1/pets/1',
path_pattern='/v1/pets/{petId}', view_args={'petId': '1'}, path_pattern='/v1/pets/{petId}', view_args={'petId': '1'},
) )

View file

@ -1,7 +1,7 @@
import mock import mock
import pytest import pytest
from openapi_core.exceptions import InvalidOperationError from openapi_core.exceptions import InvalidOperation
from openapi_core.paths import Path from openapi_core.paths import Path
from openapi_core.specs import Spec from openapi_core.specs import Spec
@ -42,9 +42,9 @@ class TestSpecs(object):
assert operation == mock.sentinel.path1_get assert operation == mock.sentinel.path1_get
def test_invalid_path(self, spec): def test_invalid_path(self, spec):
with pytest.raises(InvalidOperationError): with pytest.raises(InvalidOperation):
spec.get_operation('/path3', 'get') spec.get_operation('/path3', 'get')
def test_invalid_method(self, spec): def test_invalid_method(self, spec):
with pytest.raises(InvalidOperationError): with pytest.raises(InvalidOperation):
spec.get_operation('/path1', 'post') spec.get_operation('/path1', 'post')