servers with request validation

This commit is contained in:
Artur Maciag 2017-09-25 12:22:55 +01:00
parent 84546fee8e
commit d60bde446d
6 changed files with 219 additions and 22 deletions

View file

@ -15,3 +15,7 @@ class MissingParameterError(OpenAPIMappingError):
class InvalidContentTypeError(OpenAPIMappingError):
pass
class InvalidServerError(OpenAPIMappingError):
pass

76
openapi_core/servers.py Normal file
View file

@ -0,0 +1,76 @@
from functools import lru_cache
from six import iteritems
class Server(object):
def __init__(self, url, variables=None):
self.url = url
self.variables = variables and dict(variables) or {}
@property
def default_url(self):
return self.get_url()
@property
def default_variables(self):
defaults = {}
for name, variable in iteritems(self.variables):
defaults[name] = variable.default
return defaults
def get_url(self, **variables):
if not variables:
variables = self.default_variables
return self.url.format(**variables)
class ServerVariable(object):
def __init__(self, name, default, enum=None):
self.name = name
self.default = default
self.enum = enum and list(enum) or []
class ServersGenerator(object):
def __init__(self, dereferencer):
self.dereferencer = dereferencer
def generate(self, servers_spec):
servers_deref = self.dereferencer.dereference(servers_spec)
for server_spec in servers_deref:
url = server_spec['url']
variables_spec = server_spec.get('variables', {})
variables = None
if variables_spec:
variables = self.variables_generator.generate(variables_spec)
yield Server(url, variables=variables)
@property
@lru_cache()
def variables_generator(self):
return ServerVariablesGenerator(self.dereferencer)
class ServerVariablesGenerator(object):
def __init__(self, dereferencer):
self.dereferencer = dereferencer
def generate(self, variables_spec):
variables_deref = self.dereferencer.dereference(variables_spec)
if not variables_deref:
return [Server('/'), ]
for variable_name, variable_spec in iteritems(variables_deref):
default = variable_spec['default']
enum = variable_spec.get('enum')
variable = ServerVariable(variable_name, default, enum=enum)
yield variable_name, variable

View file

@ -9,6 +9,7 @@ from openapi_core.components import ComponentsFactory
from openapi_core.infos import InfoFactory
from openapi_core.paths import PathsGenerator
from openapi_core.schemas import SchemaRegistry
from openapi_core.servers import ServersGenerator
log = logging.getLogger(__name__)
@ -26,8 +27,12 @@ class Spec(object):
def __getitem__(self, path_name):
return self.paths[path_name]
@property
def default_url(self):
return self.servers[0].default_url
def get_server_url(self, index=0):
return self.servers[index]['url']
return self.servers[index].default_url
def get_operation(self, path_pattern, http_method):
return self.paths[path_pattern].operations[http_method]
@ -59,14 +64,16 @@ class SpecFactory(object):
spec_dict_deref = self.dereferencer.dereference(spec_dict)
info_spec = spec_dict_deref.get('info', [])
servers = spec_dict_deref.get('servers', [])
servers_spec = spec_dict_deref.get('servers', [])
paths = spec_dict_deref.get('paths', [])
components_spec = spec_dict_deref.get('components', [])
info = self.info_factory.create(info_spec)
servers = self.servers_generator.generate(servers_spec)
paths = self.paths_generator.generate(paths)
components = self.components_factory.create(components_spec)
spec = Spec(info, list(paths), servers=servers, components=components)
spec = Spec(
info, list(paths), servers=list(servers), components=components)
return spec
@property
@ -79,6 +86,11 @@ class SpecFactory(object):
def info_factory(self):
return InfoFactory(self.dereferencer)
@property
@lru_cache()
def servers_generator(self):
return ServersGenerator(self.dereferencer)
@property
@lru_cache()
def paths_generator(self):

View file

@ -1,8 +1,10 @@
"""OpenAPI core wrappers module"""
from six import iteritems
from six.moves.urllib.parse import urljoin
from openapi_core.exceptions import (
OpenAPIMappingError, MissingParameterError, InvalidContentTypeError,
InvalidServerError,
)
SPEC_LOCATION_TO_REQUEST_LOCATION_MAPPING = {
@ -32,13 +34,32 @@ class RequestParameters(dict):
"Unknown parameter location: {0}".format(str(location)))
class RequestParametersFactory(object):
class BaseRequestFactory(object):
def get_operation(self, request, spec):
server = self._get_server(request, spec)
operation_pattern = request.full_url_pattern.replace(
server.default_url, '')
return spec.get_operation(operation_pattern, request.method)
def _get_server(self, request, spec):
for server in spec.servers:
if server.default_url in request.full_url_pattern:
return server
raise InvalidServerError(
"Invalid request server {0}".format(request.full_url_pattern))
class RequestParametersFactory(BaseRequestFactory):
def __init__(self, attr_mapping=SPEC_LOCATION_TO_REQUEST_LOCATION_MAPPING):
self.attr_mapping = attr_mapping
def create(self, request, spec):
operation = spec.get_operation(request.path_pattern, request.method)
operation = self.get_operation(request, spec)
params = RequestParameters()
for param_name, param in iteritems(operation.parameters):
@ -65,10 +86,10 @@ class RequestParametersFactory(object):
return param.unmarshal(raw_value)
class RequestBodyFactory(object):
class RequestBodyFactory(BaseRequestFactory):
def create(self, request, spec):
operation = spec.get_operation(request.path_pattern, request.method)
operation = self.get_operation(request, spec)
try:
media_type = operation.request_body[request.content_type]
@ -78,9 +99,13 @@ class RequestBodyFactory(object):
return media_type.unmarshal(request.data)
def _get_operation(self, request, spec):
return spec.get_operation(request.path_pattern, request.method)
class BaseOpenAPIRequest(object):
host_url = NotImplemented
path = NotImplemented
path_pattern = NotImplemented
method = NotImplemented
@ -94,6 +119,10 @@ class BaseOpenAPIRequest(object):
content_type = NotImplemented
@property
def full_url_pattern(self):
return urljoin(self.host_url, self.path_pattern)
def get_parameters(self, spec):
return RequestParametersFactory().create(self, spec)

View file

@ -5,7 +5,13 @@ info:
license:
name: MIT
servers:
- url: http://petstore.swagger.io/v1
- url: http://petstore.swagger.io/{version}
variables:
version:
enum:
- v1
- v2
default: v1
paths:
/pets:
get:

View file

@ -3,13 +3,14 @@ import pytest
from six import iteritems
from openapi_core.exceptions import (
MissingParameterError, InvalidContentTypeError,
MissingParameterError, InvalidContentTypeError, InvalidServerError,
)
from openapi_core.media_types import MediaType
from openapi_core.operations import Operation
from openapi_core.paths import Path
from openapi_core.request_bodies import RequestBody
from openapi_core.schemas import Schema
from openapi_core.servers import Server, ServerVariable
from openapi_core.shortcuts import create_spec
from openapi_core.wrappers import BaseOpenAPIRequest
@ -17,9 +18,10 @@ from openapi_core.wrappers import BaseOpenAPIRequest
class RequestMock(BaseOpenAPIRequest):
def __init__(
self, method, path, path_pattern=None, args=None, view_args=None,
headers=None, cookies=None, data=None,
self, host_url, method, path, path_pattern=None, args=None,
view_args=None, headers=None, cookies=None, data=None,
content_type='application/json'):
self.host_url = host_url
self.path = path
self.path_pattern = path_pattern or path
self.method = method
@ -44,11 +46,26 @@ class TestPetstore(object):
return create_spec(spec_dict)
def test_spec(self, spec, spec_dict):
url = 'http://petstore.swagger.io/v1'
assert spec.info.title == spec_dict['info']['title']
assert spec.info.version == spec_dict['info']['version']
assert spec.servers == spec_dict['servers']
assert spec.get_server_url() == spec_dict['servers'][0]['url']
assert spec.get_server_url() == url
for idx, server in enumerate(spec.servers):
assert type(server) == Server
server_spec = spec_dict['servers'][idx]
assert server.url == server_spec['url']
assert server.default_url == url
for variable_name, variable in iteritems(server.variables):
assert type(variable) == ServerVariable
assert variable.name == variable_name
variable_spec = server_spec['variables'][variable_name]
assert variable.default == variable_spec['default']
assert variable.enum == variable_spec.get('enum')
for path_name, path in iteritems(spec.paths):
assert type(path) == Path
@ -99,12 +116,17 @@ class TestPetstore(object):
assert type(schema) == Schema
def test_get_pets(self, spec):
host_url = 'http://petstore.swagger.io/v1'
path_pattern = '/v1/pets'
query_params = {
'limit': '20',
'ids': ['12', '13'],
}
request = RequestMock('get', '/pets', args=query_params)
request = RequestMock(
host_url, 'get', '/pets',
path_pattern=path_pattern, args=query_params,
)
parameters = request.get_parameters(spec)
@ -116,17 +138,27 @@ class TestPetstore(object):
}
def test_get_pets_raises_missing_required_param(self, spec):
request = RequestMock('get', '/pets')
host_url = 'http://petstore.swagger.io/v1'
path_pattern = '/v1/pets'
request = RequestMock(
host_url, 'get', '/pets',
path_pattern=path_pattern,
)
with pytest.raises(MissingParameterError):
request.get_parameters(spec)
def test_get_pets_failed_to_cast(self, spec):
host_url = 'http://petstore.swagger.io/v1'
path_pattern = '/v1/pets'
query_params = {
'limit': 'non_integer_value',
}
request = RequestMock('get', '/pets', args=query_params)
request = RequestMock(
host_url, 'get', '/pets',
path_pattern=path_pattern, args=query_params,
)
parameters = request.get_parameters(spec)
@ -137,11 +169,16 @@ class TestPetstore(object):
}
def test_get_pets_empty_value(self, spec):
host_url = 'http://petstore.swagger.io/v1'
path_pattern = '/v1/pets'
query_params = {
'limit': '',
}
request = RequestMock('get', '/pets', args=query_params)
request = RequestMock(
host_url, 'get', '/pets',
path_pattern=path_pattern, args=query_params,
)
parameters = request.get_parameters(spec)
@ -152,11 +189,16 @@ class TestPetstore(object):
}
def test_get_pets_none_value(self, spec):
host_url = 'http://petstore.swagger.io/v1'
path_pattern = '/v1/pets'
query_params = {
'limit': None,
}
request = RequestMock('get', '/pets', args=query_params)
request = RequestMock(
host_url, 'get', '/pets',
path_pattern=path_pattern, args=query_params,
)
parameters = request.get_parameters(spec)
@ -167,6 +209,8 @@ class TestPetstore(object):
}
def test_post_pets(self, spec, spec_dict):
host_url = 'http://petstore.swagger.io/v1'
path_pattern = '/v1/pets'
pet_name = 'Cat'
pet_tag = 'cats'
pet_street = 'Piekna'
@ -181,7 +225,10 @@ class TestPetstore(object):
}
data = json.dumps(data_json)
request = RequestMock('post', '/pets', data=data)
request = RequestMock(
host_url, 'post', '/pets',
path_pattern=path_pattern, data=data,
)
pet = request.get_body(spec)
@ -196,6 +243,8 @@ class TestPetstore(object):
assert pet.address.city == pet_city
def test_post_pets_raises_invalid_content_type(self, spec):
host_url = 'http://petstore.swagger.io/v1'
path_pattern = '/v1/pets'
data_json = {
'name': 'Cat',
'tag': 'cats',
@ -203,18 +252,39 @@ class TestPetstore(object):
data = json.dumps(data_json)
request = RequestMock(
'post', '/pets', data=data, content_type='text/html')
host_url, 'post', '/pets',
path_pattern=path_pattern, data=data, content_type='text/html',
)
with pytest.raises(InvalidContentTypeError):
request.get_body(spec)
def test_post_pets_raises_invalid_server_error(self, spec):
host_url = 'http://flowerstore.swagger.io/v1'
path_pattern = '/v1/pets'
data_json = {
'name': 'Cat',
'tag': 'cats',
}
data = json.dumps(data_json)
request = RequestMock(
host_url, 'post', '/pets',
path_pattern=path_pattern, data=data, content_type='text/html',
)
with pytest.raises(InvalidServerError):
request.get_body(spec)
def test_get_pet(self, spec):
host_url = 'http://petstore.swagger.io/v1'
path_pattern = '/v1/pets/{petId}'
view_args = {
'petId': '1',
}
request = RequestMock(
'get', '/pets/1', path_pattern='/pets/{petId}',
view_args=view_args,
host_url, 'get', '/pets/1',
path_pattern=path_pattern, view_args=view_args,
)
parameters = request.get_parameters(spec)