diff --git a/openapi_core/parameters.py b/openapi_core/parameters.py index 6c12c35..cafde80 100644 --- a/openapi_core/parameters.py +++ b/openapi_core/parameters.py @@ -2,15 +2,23 @@ import logging import warnings +from functools import lru_cache from six import iteritems -from openapi_core.enums import ParameterLocation, ParameterStyle +from openapi_core.enums import ParameterLocation, ParameterStyle, SchemaType from openapi_core.exceptions import ( EmptyValue, InvalidValueType, InvalidParameterValue, ) log = logging.getLogger(__name__) +PARAMETER_STYLE_DESERIALIZERS = { + ParameterStyle.FORM: lambda x: x.split(','), + ParameterStyle.SIMPLE: lambda x: x.split(','), + ParameterStyle.SPACE_DELIMITED: lambda x: x.split(' '), + ParameterStyle.PIPE_DELIMITED: lambda x: x.split('|'), +} + class Parameter(object): """Represents an OpenAPI operation Parameter.""" @@ -32,7 +40,14 @@ class Parameter(object): ) self.items = items self.style = ParameterStyle(style or self.default_style) - self.explode = explode or self.default_explode + self.explode = self.default_explode if explode is None else explode + + @property + def aslist(self): + return ( + self.schema and + self.schema.type in [SchemaType.ARRAY, SchemaType.OBJECT] + ) @property def default_style(self): @@ -45,6 +60,16 @@ class Parameter(object): def default_explode(self): return self.style == ParameterStyle.FORM + def get_dererializer(self): + return PARAMETER_STYLE_DESERIALIZERS[self.style] + + def deserialize(self, value): + if not self.aslist or self.explode: + return value + + deserializer = self.get_dererializer() + return deserializer(value) + def unmarshal(self, value): if self.deprecated: warnings.warn( @@ -60,12 +85,45 @@ class Parameter(object): if not self.schema: return value + deserialized = self.deserialize(value) + try: - return self.schema.unmarshal(value) + return self.schema.unmarshal(deserialized) except InvalidValueType as exc: raise InvalidParameterValue(str(exc)) +class ParameterFactory(object): + + def __init__(self, dereferencer, schemas_registry): + self.dereferencer = dereferencer + self.schemas_registry = schemas_registry + + def create(self, parameter_spec, parameter_name=None): + parameter_deref = self.dereferencer.dereference(parameter_spec) + + parameter_name = parameter_name or parameter_deref['name'] + parameter_in = parameter_deref.get('in', 'header') + + allow_empty_value = parameter_deref.get('allowEmptyValue') + required = parameter_deref.get('required', False) + + style = parameter_deref.get('style') + explode = parameter_deref.get('explode') + + schema_spec = parameter_deref.get('schema', None) + schema = None + if schema_spec: + schema, _ = self.schemas_registry.get_or_create(schema_spec) + + return Parameter( + parameter_name, parameter_in, + schema=schema, required=required, + allow_empty_value=allow_empty_value, + style=style, explode=explode, + ) + + class ParametersGenerator(object): def __init__(self, dereferencer, schemas_registry): @@ -73,52 +131,19 @@ class ParametersGenerator(object): self.schemas_registry = schemas_registry def generate(self, parameters): - for parameter_name, parameter in iteritems(parameters): - parameter_deref = self.dereferencer.dereference(parameter) + for parameter_name, parameter_spec in iteritems(parameters): + parameter = self.parameter_factory.create( + parameter_spec, parameter_name=parameter_name) - parameter_in = parameter_deref.get('in', 'header') - - allow_empty_value = parameter_deref.get('allowEmptyValue') - required = parameter_deref.get('required', False) - - schema_spec = parameter_deref.get('schema', None) - schema = None - if schema_spec: - schema, _ = self.schemas_registry.get_or_create(schema_spec) - - yield ( - parameter_name, - Parameter( - parameter_name, parameter_in, - schema=schema, required=required, - allow_empty_value=allow_empty_value, - ), - ) + yield (parameter_name, parameter) def generate_from_list(self, parameters_list): - for parameter in parameters_list: - parameter_deref = self.dereferencer.dereference(parameter) + for parameter_spec in parameters_list: + parameter = self.parameter_factory.create(parameter_spec) - parameter_name = parameter_deref['name'] - parameter_in = parameter_deref.get('in', 'header') + yield (parameter.name, parameter) - allow_empty_value = parameter_deref.get('allowEmptyValue') - required = parameter_deref.get('required', False) - - style = parameter_deref.get('style') - explode = parameter_deref.get('explode') - - schema_spec = parameter_deref.get('schema', None) - schema = None - if schema_spec: - schema, _ = self.schemas_registry.get_or_create(schema_spec) - - yield ( - parameter_name, - Parameter( - parameter_name, parameter_in, - schema=schema, required=required, - allow_empty_value=allow_empty_value, - style=style, explode=explode, - ), - ) + @property + @lru_cache() + def parameter_factory(self): + return ParameterFactory(self.dereferencer, self.schemas_registry) diff --git a/openapi_core/validators.py b/openapi_core/validators.py index 8d3fcdd..ebdfd98 100644 --- a/openapi_core/validators.py +++ b/openapi_core/validators.py @@ -1,7 +1,6 @@ """OpenAPI core validators module""" from six import iteritems -from openapi_core.enums import ParameterLocation from openapi_core.exceptions import ( OpenAPIMappingError, MissingParameter, MissingBody, InvalidResponse, ) @@ -9,6 +8,8 @@ from openapi_core.exceptions import ( class RequestParameters(dict): + valid_locations = ['path', 'query', 'headers', 'cookies'] + def __getitem__(self, location): self.validate_location(location) @@ -19,7 +20,7 @@ class RequestParameters(dict): @classmethod def validate_location(cls, location): - if not ParameterLocation.has_value(location): + if location not in cls.valid_locations: raise OpenAPIMappingError( "Unknown parameter location: {0}".format(str(location))) @@ -116,12 +117,19 @@ class RequestValidator(object): return RequestValidationResult(errors, body, parameters) def _get_raw_value(self, request, param): + location = request.parameters[param.location.value] + try: - return request.parameters[param.location.value][param.name] + raw = request.parameters[param.location.value][param.name] except KeyError: raise MissingParameter( "Missing required `{0}` parameter".format(param.name)) + if param.aslist and param.explode: + return location.getlist(param.name) + + return raw + def _get_raw_body(self, request): if not request.body: raise MissingBody("Missing required request body") diff --git a/openapi_core/wrappers.py b/openapi_core/wrappers.py index 7ce7ba8..019e242 100644 --- a/openapi_core/wrappers.py +++ b/openapi_core/wrappers.py @@ -2,6 +2,7 @@ import warnings from six.moves.urllib.parse import urljoin +from werkzeug.datastructures import ImmutableMultiDict class BaseOpenAPIRequest(object): @@ -54,9 +55,9 @@ class MockRequest(BaseOpenAPIRequest): self.parameters = { 'path': view_args or {}, - 'query': args or {}, - 'headers': headers or {}, - 'cookies': cookies or {}, + 'query': ImmutableMultiDict(args or []), + 'header': headers or {}, + 'cookie': cookies or {}, } self.body = data or '' diff --git a/tests/integration/data/v3.0/petstore.yaml b/tests/integration/data/v3.0/petstore.yaml index 82b0600..6dbdf89 100644 --- a/tests/integration/data/v3.0/petstore.yaml +++ b/tests/integration/data/v3.0/petstore.yaml @@ -49,6 +49,14 @@ paths: items: type: integer format: int32 + - name: tags + in: query + description: Filter pets with tags + schema: + type: array + items: + $ref: "#/components/schemas/Tag" + explode: false responses: '200': description: An paged array of pets diff --git a/tests/integration/test_petstore.py b/tests/integration/test_petstore.py index 910095d..e53c038 100644 --- a/tests/integration/test_petstore.py +++ b/tests/integration/test_petstore.py @@ -177,6 +177,41 @@ class TestPetstore(object): assert type(schema) == Schema def test_get_pets(self, spec, response_validator): + host_url = 'http://petstore.swagger.io/v1' + path_pattern = '/v1/pets' + query_params = { + 'limit': '20', + } + + request = MockRequest( + host_url, 'GET', '/pets', + path_pattern=path_pattern, args=query_params, + ) + + parameters = request.get_parameters(spec) + body = request.get_body(spec) + + assert parameters == { + 'query': { + 'limit': 20, + 'page': 1, + 'search': '', + } + } + assert body is None + + data_json = { + 'data': [], + } + data = json.dumps(data_json) + response = MockResponse(data) + + response_result = response_validator.validate(request, response) + + assert response_result.errors == [] + assert response_result.data == data_json + + def test_get_pets_ids_param(self, spec, response_validator): host_url = 'http://petstore.swagger.io/v1' path_pattern = '/v1/pets' query_params = { @@ -213,13 +248,13 @@ class TestPetstore(object): assert response_result.errors == [] assert response_result.data == data_json - def test_get_pets_tag_param(self, spec, response_validator): + def test_get_pets_tags_param(self, spec, response_validator): host_url = 'http://petstore.swagger.io/v1' path_pattern = '/v1/pets' - query_params = { - 'limit': '20', - 'ids': ['12', '13'], - } + query_params = [ + ('limit', '20'), + ('tags', 'cats,dogs'), + ] request = MockRequest( host_url, 'GET', '/pets', @@ -234,7 +269,7 @@ class TestPetstore(object): 'limit': 20, 'page': 1, 'search': '', - 'ids': [12, 13], + 'tags': ['cats', 'dogs'], } } assert body is None diff --git a/tests/integration/test_wrappers.py b/tests/integration/test_wrappers.py index 0a4ff01..b33505e 100644 --- a/tests/integration/test_wrappers.py +++ b/tests/integration/test_wrappers.py @@ -13,8 +13,8 @@ class TestFlaskOpenAPIRequest(object): server_name = 'localhost' @pytest.fixture - def environ(self): - return create_environ() + def environ_factory(self): + return create_environ @pytest.fixture def map(self): @@ -33,8 +33,9 @@ class TestFlaskOpenAPIRequest(object): ], default_subdomain='www') @pytest.fixture - def request_factory(self, map, environ): - def create_request(method, path, subdomain=None): + def request_factory(self, map, environ_factory): + def create_request(method, path, subdomain=None, query_string=None): + environ = environ_factory(query_string=query_string) req = Request(environ) urls = map.bind_to_environ( environ, server_name=self.server_name, subdomain=subdomain) @@ -47,14 +48,14 @@ class TestFlaskOpenAPIRequest(object): def openapi_request(self, request): return FlaskOpenAPIRequest(request) - def test_simple(self, request_factory, environ, request): + def test_simple(self, request_factory, request): request = request_factory('GET', '/', subdomain='www') openapi_request = FlaskOpenAPIRequest(request) path = {} query = ImmutableMultiDict([]) - headers = EnvironHeaders(environ) + headers = EnvironHeaders(request.environ) cookies = {} assert openapi_request.parameters == { 'path': path, @@ -69,14 +70,39 @@ class TestFlaskOpenAPIRequest(object): assert openapi_request.body == request.data assert openapi_request.mimetype == request.mimetype - def test_url_rule(self, request_factory, environ, request): + def test_multiple_values(self, request_factory, request): + request = request_factory( + 'GET', '/', subdomain='www', query_string='a=b&a=c') + + openapi_request = FlaskOpenAPIRequest(request) + + path = {} + query = ImmutableMultiDict([ + ('a', 'b'), ('a', 'c'), + ]) + headers = EnvironHeaders(request.environ) + cookies = {} + assert openapi_request.parameters == { + 'path': path, + 'query': query, + 'headers': headers, + 'cookies': cookies, + } + assert openapi_request.host_url == request.host_url + assert openapi_request.path == request.path + assert openapi_request.method == request.method.lower() + assert openapi_request.path_pattern == request.path + assert openapi_request.body == request.data + assert openapi_request.mimetype == request.mimetype + + def test_url_rule(self, request_factory, request): request = request_factory('GET', '/browse/12/', subdomain='kb') openapi_request = FlaskOpenAPIRequest(request) path = {'id': 12} query = ImmutableMultiDict([]) - headers = EnvironHeaders(environ) + headers = EnvironHeaders(request.environ) cookies = {} assert openapi_request.parameters == { 'path': path,