diff --git a/openapi_core/exceptions.py b/openapi_core/exceptions.py index 6af7d53..9a48425 100644 --- a/openapi_core/exceptions.py +++ b/openapi_core/exceptions.py @@ -15,3 +15,7 @@ class MissingParameterError(OpenAPIMappingError): class InvalidContentTypeError(OpenAPIMappingError): pass + + +class InvalidServerError(OpenAPIMappingError): + pass diff --git a/openapi_core/servers.py b/openapi_core/servers.py new file mode 100644 index 0000000..68e1dc3 --- /dev/null +++ b/openapi_core/servers.py @@ -0,0 +1,76 @@ +from functools import lru_cache + +from six import iteritems + + +class Server(object): + + def __init__(self, url, variables=None): + self.url = url + self.variables = variables and dict(variables) or {} + + @property + def default_url(self): + return self.get_url() + + @property + def default_variables(self): + defaults = {} + for name, variable in iteritems(self.variables): + defaults[name] = variable.default + return defaults + + def get_url(self, **variables): + if not variables: + variables = self.default_variables + return self.url.format(**variables) + + +class ServerVariable(object): + + def __init__(self, name, default, enum=None): + self.name = name + self.default = default + self.enum = enum and list(enum) or [] + + +class ServersGenerator(object): + + def __init__(self, dereferencer): + self.dereferencer = dereferencer + + def generate(self, servers_spec): + servers_deref = self.dereferencer.dereference(servers_spec) + for server_spec in servers_deref: + url = server_spec['url'] + variables_spec = server_spec.get('variables', {}) + + variables = None + if variables_spec: + variables = self.variables_generator.generate(variables_spec) + + yield Server(url, variables=variables) + + @property + @lru_cache() + def variables_generator(self): + return ServerVariablesGenerator(self.dereferencer) + + +class ServerVariablesGenerator(object): + + def __init__(self, dereferencer): + self.dereferencer = dereferencer + + def generate(self, variables_spec): + variables_deref = self.dereferencer.dereference(variables_spec) + + if not variables_deref: + return [Server('/'), ] + + for variable_name, variable_spec in iteritems(variables_deref): + default = variable_spec['default'] + enum = variable_spec.get('enum') + + variable = ServerVariable(variable_name, default, enum=enum) + yield variable_name, variable diff --git a/openapi_core/specs.py b/openapi_core/specs.py index d064184..c49a73c 100644 --- a/openapi_core/specs.py +++ b/openapi_core/specs.py @@ -9,6 +9,7 @@ from openapi_core.components import ComponentsFactory from openapi_core.infos import InfoFactory from openapi_core.paths import PathsGenerator from openapi_core.schemas import SchemaRegistry +from openapi_core.servers import ServersGenerator log = logging.getLogger(__name__) @@ -26,8 +27,12 @@ class Spec(object): def __getitem__(self, path_name): return self.paths[path_name] + @property + def default_url(self): + return self.servers[0].default_url + def get_server_url(self, index=0): - return self.servers[index]['url'] + return self.servers[index].default_url def get_operation(self, path_pattern, http_method): return self.paths[path_pattern].operations[http_method] @@ -59,14 +64,16 @@ class SpecFactory(object): spec_dict_deref = self.dereferencer.dereference(spec_dict) info_spec = spec_dict_deref.get('info', []) - servers = spec_dict_deref.get('servers', []) + servers_spec = spec_dict_deref.get('servers', []) paths = spec_dict_deref.get('paths', []) components_spec = spec_dict_deref.get('components', []) info = self.info_factory.create(info_spec) + servers = self.servers_generator.generate(servers_spec) paths = self.paths_generator.generate(paths) components = self.components_factory.create(components_spec) - spec = Spec(info, list(paths), servers=servers, components=components) + spec = Spec( + info, list(paths), servers=list(servers), components=components) return spec @property @@ -79,6 +86,11 @@ class SpecFactory(object): def info_factory(self): return InfoFactory(self.dereferencer) + @property + @lru_cache() + def servers_generator(self): + return ServersGenerator(self.dereferencer) + @property @lru_cache() def paths_generator(self): diff --git a/openapi_core/wrappers.py b/openapi_core/wrappers.py index cd5945f..3e382d0 100644 --- a/openapi_core/wrappers.py +++ b/openapi_core/wrappers.py @@ -1,8 +1,10 @@ """OpenAPI core wrappers module""" from six import iteritems +from six.moves.urllib.parse import urljoin from openapi_core.exceptions import ( OpenAPIMappingError, MissingParameterError, InvalidContentTypeError, + InvalidServerError, ) SPEC_LOCATION_TO_REQUEST_LOCATION_MAPPING = { @@ -32,13 +34,32 @@ class RequestParameters(dict): "Unknown parameter location: {0}".format(str(location))) -class RequestParametersFactory(object): +class BaseRequestFactory(object): + + def get_operation(self, request, spec): + server = self._get_server(request, spec) + + operation_pattern = request.full_url_pattern.replace( + server.default_url, '') + + return spec.get_operation(operation_pattern, request.method) + + def _get_server(self, request, spec): + for server in spec.servers: + if server.default_url in request.full_url_pattern: + return server + + raise InvalidServerError( + "Invalid request server {0}".format(request.full_url_pattern)) + + +class RequestParametersFactory(BaseRequestFactory): def __init__(self, attr_mapping=SPEC_LOCATION_TO_REQUEST_LOCATION_MAPPING): self.attr_mapping = attr_mapping def create(self, request, spec): - operation = spec.get_operation(request.path_pattern, request.method) + operation = self.get_operation(request, spec) params = RequestParameters() for param_name, param in iteritems(operation.parameters): @@ -65,10 +86,10 @@ class RequestParametersFactory(object): return param.unmarshal(raw_value) -class RequestBodyFactory(object): +class RequestBodyFactory(BaseRequestFactory): def create(self, request, spec): - operation = spec.get_operation(request.path_pattern, request.method) + operation = self.get_operation(request, spec) try: media_type = operation.request_body[request.content_type] @@ -78,9 +99,13 @@ class RequestBodyFactory(object): return media_type.unmarshal(request.data) + def _get_operation(self, request, spec): + return spec.get_operation(request.path_pattern, request.method) + class BaseOpenAPIRequest(object): + host_url = NotImplemented path = NotImplemented path_pattern = NotImplemented method = NotImplemented @@ -94,6 +119,10 @@ class BaseOpenAPIRequest(object): content_type = NotImplemented + @property + def full_url_pattern(self): + return urljoin(self.host_url, self.path_pattern) + def get_parameters(self, spec): return RequestParametersFactory().create(self, spec) diff --git a/tests/integration/data/v3.0/petstore.yaml b/tests/integration/data/v3.0/petstore.yaml index 1228873..a11f4a9 100644 --- a/tests/integration/data/v3.0/petstore.yaml +++ b/tests/integration/data/v3.0/petstore.yaml @@ -5,7 +5,13 @@ info: license: name: MIT servers: - - url: http://petstore.swagger.io/v1 + - url: http://petstore.swagger.io/{version} + variables: + version: + enum: + - v1 + - v2 + default: v1 paths: /pets: get: diff --git a/tests/integration/test_petstore.py b/tests/integration/test_petstore.py index 82ca975..ad5ad4b 100644 --- a/tests/integration/test_petstore.py +++ b/tests/integration/test_petstore.py @@ -3,13 +3,14 @@ import pytest from six import iteritems from openapi_core.exceptions import ( - MissingParameterError, InvalidContentTypeError, + MissingParameterError, InvalidContentTypeError, InvalidServerError, ) from openapi_core.media_types import MediaType from openapi_core.operations import Operation from openapi_core.paths import Path from openapi_core.request_bodies import RequestBody from openapi_core.schemas import Schema +from openapi_core.servers import Server, ServerVariable from openapi_core.shortcuts import create_spec from openapi_core.wrappers import BaseOpenAPIRequest @@ -17,9 +18,10 @@ from openapi_core.wrappers import BaseOpenAPIRequest class RequestMock(BaseOpenAPIRequest): def __init__( - self, method, path, path_pattern=None, args=None, view_args=None, - headers=None, cookies=None, data=None, + self, host_url, method, path, path_pattern=None, args=None, + view_args=None, headers=None, cookies=None, data=None, content_type='application/json'): + self.host_url = host_url self.path = path self.path_pattern = path_pattern or path self.method = method @@ -44,11 +46,26 @@ class TestPetstore(object): return create_spec(spec_dict) def test_spec(self, spec, spec_dict): + url = 'http://petstore.swagger.io/v1' assert spec.info.title == spec_dict['info']['title'] assert spec.info.version == spec_dict['info']['version'] - assert spec.servers == spec_dict['servers'] - assert spec.get_server_url() == spec_dict['servers'][0]['url'] + assert spec.get_server_url() == url + + for idx, server in enumerate(spec.servers): + assert type(server) == Server + + server_spec = spec_dict['servers'][idx] + assert server.url == server_spec['url'] + assert server.default_url == url + + for variable_name, variable in iteritems(server.variables): + assert type(variable) == ServerVariable + assert variable.name == variable_name + + variable_spec = server_spec['variables'][variable_name] + assert variable.default == variable_spec['default'] + assert variable.enum == variable_spec.get('enum') for path_name, path in iteritems(spec.paths): assert type(path) == Path @@ -99,12 +116,17 @@ class TestPetstore(object): assert type(schema) == Schema def test_get_pets(self, spec): + host_url = 'http://petstore.swagger.io/v1' + path_pattern = '/v1/pets' query_params = { 'limit': '20', 'ids': ['12', '13'], } - request = RequestMock('get', '/pets', args=query_params) + request = RequestMock( + host_url, 'get', '/pets', + path_pattern=path_pattern, args=query_params, + ) parameters = request.get_parameters(spec) @@ -116,17 +138,27 @@ class TestPetstore(object): } def test_get_pets_raises_missing_required_param(self, spec): - request = RequestMock('get', '/pets') + host_url = 'http://petstore.swagger.io/v1' + path_pattern = '/v1/pets' + request = RequestMock( + host_url, 'get', '/pets', + path_pattern=path_pattern, + ) with pytest.raises(MissingParameterError): request.get_parameters(spec) def test_get_pets_failed_to_cast(self, spec): + host_url = 'http://petstore.swagger.io/v1' + path_pattern = '/v1/pets' query_params = { 'limit': 'non_integer_value', } - request = RequestMock('get', '/pets', args=query_params) + request = RequestMock( + host_url, 'get', '/pets', + path_pattern=path_pattern, args=query_params, + ) parameters = request.get_parameters(spec) @@ -137,11 +169,16 @@ class TestPetstore(object): } def test_get_pets_empty_value(self, spec): + host_url = 'http://petstore.swagger.io/v1' + path_pattern = '/v1/pets' query_params = { 'limit': '', } - request = RequestMock('get', '/pets', args=query_params) + request = RequestMock( + host_url, 'get', '/pets', + path_pattern=path_pattern, args=query_params, + ) parameters = request.get_parameters(spec) @@ -152,11 +189,16 @@ class TestPetstore(object): } def test_get_pets_none_value(self, spec): + host_url = 'http://petstore.swagger.io/v1' + path_pattern = '/v1/pets' query_params = { 'limit': None, } - request = RequestMock('get', '/pets', args=query_params) + request = RequestMock( + host_url, 'get', '/pets', + path_pattern=path_pattern, args=query_params, + ) parameters = request.get_parameters(spec) @@ -167,6 +209,8 @@ class TestPetstore(object): } def test_post_pets(self, spec, spec_dict): + host_url = 'http://petstore.swagger.io/v1' + path_pattern = '/v1/pets' pet_name = 'Cat' pet_tag = 'cats' pet_street = 'Piekna' @@ -181,7 +225,10 @@ class TestPetstore(object): } data = json.dumps(data_json) - request = RequestMock('post', '/pets', data=data) + request = RequestMock( + host_url, 'post', '/pets', + path_pattern=path_pattern, data=data, + ) pet = request.get_body(spec) @@ -196,6 +243,8 @@ class TestPetstore(object): assert pet.address.city == pet_city def test_post_pets_raises_invalid_content_type(self, spec): + host_url = 'http://petstore.swagger.io/v1' + path_pattern = '/v1/pets' data_json = { 'name': 'Cat', 'tag': 'cats', @@ -203,18 +252,39 @@ class TestPetstore(object): data = json.dumps(data_json) request = RequestMock( - 'post', '/pets', data=data, content_type='text/html') + host_url, 'post', '/pets', + path_pattern=path_pattern, data=data, content_type='text/html', + ) with pytest.raises(InvalidContentTypeError): request.get_body(spec) + def test_post_pets_raises_invalid_server_error(self, spec): + host_url = 'http://flowerstore.swagger.io/v1' + path_pattern = '/v1/pets' + data_json = { + 'name': 'Cat', + 'tag': 'cats', + } + data = json.dumps(data_json) + + request = RequestMock( + host_url, 'post', '/pets', + path_pattern=path_pattern, data=data, content_type='text/html', + ) + + with pytest.raises(InvalidServerError): + request.get_body(spec) + def test_get_pet(self, spec): + host_url = 'http://petstore.swagger.io/v1' + path_pattern = '/v1/pets/{petId}' view_args = { 'petId': '1', } request = RequestMock( - 'get', '/pets/1', path_pattern='/pets/{petId}', - view_args=view_args, + host_url, 'get', '/pets/1', + path_pattern=path_pattern, view_args=view_args, ) parameters = request.get_parameters(spec)