diff --git a/openapi_core/enums.py b/openapi_core/enums.py new file mode 100644 index 0000000..05ff58f --- /dev/null +++ b/openapi_core/enums.py @@ -0,0 +1,48 @@ +from enum import Enum + + +class ParameterLocation(Enum): + + PATH = 'path' + QUERY = 'query' + HEADER = 'header' + COOKIE = 'cookie' + + @classmethod + def has_value(cls, value): + return (any(value == item.value for item in cls)) + + +class ParameterStyle(Enum): + + MATRIX = 'matrix' + LABEL = 'label' + FORM = 'form' + SIMPLE = 'simple' + SPACE_DELIMITED = 'spaceDelimited' + PIPE_DELIMITED = 'pipeDelimited' + DEEP_OBJECT = 'deepObject' + + +class SchemaType(Enum): + + INTEGER = 'integer' + NUMBER = 'number' + STRING = 'string' + BOOLEAN = 'boolean' + ARRAY = 'array' + OBJECT = 'object' + + +class SchemaFormat(Enum): + + NONE = None + INT32 = 'int32' + INT64 = 'int64' + FLOAT = 'float' + DOUBLE = 'double' + BYTE = 'byte' + BINARY = 'binary' + DATE = 'date' + DATETIME = 'date-time' + PASSWORD = 'password' diff --git a/openapi_core/parameters.py b/openapi_core/parameters.py index ecef11a..cafde80 100644 --- a/openapi_core/parameters.py +++ b/openapi_core/parameters.py @@ -2,14 +2,23 @@ import logging import warnings +from functools import lru_cache from six import iteritems +from openapi_core.enums import ParameterLocation, ParameterStyle, SchemaType from openapi_core.exceptions import ( EmptyValue, InvalidValueType, InvalidParameterValue, ) log = logging.getLogger(__name__) +PARAMETER_STYLE_DESERIALIZERS = { + ParameterStyle.FORM: lambda x: x.split(','), + ParameterStyle.SIMPLE: lambda x: x.split(','), + ParameterStyle.SPACE_DELIMITED: lambda x: x.split(' '), + ParameterStyle.PIPE_DELIMITED: lambda x: x.split('|'), +} + class Parameter(object): """Represents an OpenAPI operation Parameter.""" @@ -17,17 +26,49 @@ class Parameter(object): def __init__( self, name, location, schema=None, required=False, deprecated=False, allow_empty_value=False, - items=None, collection_format=None): + items=None, style=None, explode=None): self.name = name - self.location = location + self.location = ParameterLocation(location) self.schema = schema - self.required = True if self.location == "path" else required + self.required = ( + True if self.location == ParameterLocation.PATH else required + ) self.deprecated = deprecated self.allow_empty_value = ( - allow_empty_value if self.location == "query" else False + allow_empty_value if self.location == ParameterLocation.QUERY + else False ) self.items = items - self.collection_format = collection_format + self.style = ParameterStyle(style or self.default_style) + self.explode = self.default_explode if explode is None else explode + + @property + def aslist(self): + return ( + self.schema and + self.schema.type in [SchemaType.ARRAY, SchemaType.OBJECT] + ) + + @property + def default_style(self): + simple_locations = [ParameterLocation.PATH, ParameterLocation.HEADER] + return ( + 'simple' if self.location in simple_locations else "form" + ) + + @property + def default_explode(self): + return self.style == ParameterStyle.FORM + + def get_dererializer(self): + return PARAMETER_STYLE_DESERIALIZERS[self.style] + + def deserialize(self, value): + if not self.aslist or self.explode: + return value + + deserializer = self.get_dererializer() + return deserializer(value) def unmarshal(self, value): if self.deprecated: @@ -36,7 +77,7 @@ class Parameter(object): DeprecationWarning, ) - if (self.location == "query" and value == "" and + if (self.location == ParameterLocation.QUERY and value == "" and not self.allow_empty_value): raise EmptyValue( "Value of {0} parameter cannot be empty".format(self.name)) @@ -44,12 +85,45 @@ class Parameter(object): if not self.schema: return value + deserialized = self.deserialize(value) + try: - return self.schema.unmarshal(value) + return self.schema.unmarshal(deserialized) except InvalidValueType as exc: raise InvalidParameterValue(str(exc)) +class ParameterFactory(object): + + def __init__(self, dereferencer, schemas_registry): + self.dereferencer = dereferencer + self.schemas_registry = schemas_registry + + def create(self, parameter_spec, parameter_name=None): + parameter_deref = self.dereferencer.dereference(parameter_spec) + + parameter_name = parameter_name or parameter_deref['name'] + parameter_in = parameter_deref.get('in', 'header') + + allow_empty_value = parameter_deref.get('allowEmptyValue') + required = parameter_deref.get('required', False) + + style = parameter_deref.get('style') + explode = parameter_deref.get('explode') + + schema_spec = parameter_deref.get('schema', None) + schema = None + if schema_spec: + schema, _ = self.schemas_registry.get_or_create(schema_spec) + + return Parameter( + parameter_name, parameter_in, + schema=schema, required=required, + allow_empty_value=allow_empty_value, + style=style, explode=explode, + ) + + class ParametersGenerator(object): def __init__(self, dereferencer, schemas_registry): @@ -57,48 +131,19 @@ class ParametersGenerator(object): self.schemas_registry = schemas_registry def generate(self, parameters): - for parameter_name, parameter in iteritems(parameters): - parameter_deref = self.dereferencer.dereference(parameter) + for parameter_name, parameter_spec in iteritems(parameters): + parameter = self.parameter_factory.create( + parameter_spec, parameter_name=parameter_name) - parameter_in = parameter_deref.get('in', 'header') - - allow_empty_value = parameter_deref.get('allowEmptyValue') - required = parameter_deref.get('required', False) - - schema_spec = parameter_deref.get('schema', None) - schema = None - if schema_spec: - schema, _ = self.schemas_registry.get_or_create(schema_spec) - - yield ( - parameter_name, - Parameter( - parameter_name, parameter_in, - schema=schema, required=required, - allow_empty_value=allow_empty_value, - ), - ) + yield (parameter_name, parameter) def generate_from_list(self, parameters_list): - for parameter in parameters_list: - parameter_deref = self.dereferencer.dereference(parameter) + for parameter_spec in parameters_list: + parameter = self.parameter_factory.create(parameter_spec) - parameter_name = parameter_deref['name'] - parameter_in = parameter_deref.get('in', 'header') + yield (parameter.name, parameter) - allow_empty_value = parameter_deref.get('allowEmptyValue') - required = parameter_deref.get('required', False) - - schema_spec = parameter_deref.get('schema', None) - schema = None - if schema_spec: - schema, _ = self.schemas_registry.get_or_create(schema_spec) - - yield ( - parameter_name, - Parameter( - parameter_name, parameter_in, - schema=schema, required=required, - allow_empty_value=allow_empty_value, - ), - ) + @property + @lru_cache() + def parameter_factory(self): + return ParameterFactory(self.dereferencer, self.schemas_registry) diff --git a/openapi_core/schemas.py b/openapi_core/schemas.py index d51640a..534989b 100644 --- a/openapi_core/schemas.py +++ b/openapi_core/schemas.py @@ -9,6 +9,7 @@ from functools import lru_cache from json import loads from six import iteritems +from openapi_core.enums import SchemaType, SchemaFormat from openapi_core.exceptions import ( InvalidValueType, UndefinedSchemaProperty, MissingProperty, InvalidValue, ) @@ -17,9 +18,9 @@ from openapi_core.models import ModelFactory log = logging.getLogger(__name__) DEFAULT_CAST_CALLABLE_GETTER = { - 'integer': int, - 'number': float, - 'boolean': lambda x: bool(strtobool(x)), + SchemaType.INTEGER: int, + SchemaType.NUMBER: float, + SchemaType.BOOLEAN: lambda x: bool(strtobool(x)), } @@ -28,13 +29,13 @@ class Schema(object): def __init__( self, schema_type, model=None, properties=None, items=None, - spec_format=None, required=None, default=None, nullable=False, + schema_format=None, required=None, default=None, nullable=False, enum=None, deprecated=False, all_of=None): - self.type = schema_type + self.type = SchemaType(schema_type) self.model = model self.properties = properties and dict(properties) or {} self.items = items - self.format = spec_format + self.format = SchemaFormat(schema_format) self.required = required or [] self.default = default self.nullable = nullable @@ -57,8 +58,8 @@ class Schema(object): def get_cast_mapping(self): mapping = DEFAULT_CAST_CALLABLE_GETTER.copy() mapping.update({ - 'array': self._unmarshal_collection, - 'object': self._unmarshal_object, + SchemaType.ARRAY: self._unmarshal_collection, + SchemaType.OBJECT: self._unmarshal_object, }) return defaultdict(lambda: lambda x: x, mapping) @@ -159,6 +160,7 @@ class SchemaFactory(object): schema_deref = self.dereferencer.dereference(schema_spec) schema_type = schema_deref['type'] + schema_format = schema_deref.get('format') model = schema_deref.get('x-model', None) required = schema_deref.get('required', False) default = schema_deref.get('default', None) @@ -183,8 +185,8 @@ class SchemaFactory(object): return Schema( schema_type, model=model, properties=properties, items=items, - required=required, default=default, nullable=nullable, enum=enum, - deprecated=deprecated, all_of=all_of, + schema_format=schema_format, required=required, default=default, + nullable=nullable, enum=enum, deprecated=deprecated, all_of=all_of, ) @property diff --git a/openapi_core/validators.py b/openapi_core/validators.py index a376092..ebdfd98 100644 --- a/openapi_core/validators.py +++ b/openapi_core/validators.py @@ -95,7 +95,7 @@ class RequestValidator(object): except OpenAPIMappingError as exc: errors.append(exc) else: - parameters[param.location][param_name] = value + parameters[param.location.value][param_name] = value if operation.request_body is not None: try: @@ -117,12 +117,19 @@ class RequestValidator(object): return RequestValidationResult(errors, body, parameters) def _get_raw_value(self, request, param): + location = request.parameters[param.location.value] + try: - return request.parameters[param.location][param.name] + raw = request.parameters[param.location.value][param.name] except KeyError: raise MissingParameter( "Missing required `{0}` parameter".format(param.name)) + if param.aslist and param.explode: + return location.getlist(param.name) + + return raw + def _get_raw_body(self, request): if not request.body: raise MissingBody("Missing required request body") diff --git a/openapi_core/wrappers.py b/openapi_core/wrappers.py index 7ce7ba8..019e242 100644 --- a/openapi_core/wrappers.py +++ b/openapi_core/wrappers.py @@ -2,6 +2,7 @@ import warnings from six.moves.urllib.parse import urljoin +from werkzeug.datastructures import ImmutableMultiDict class BaseOpenAPIRequest(object): @@ -54,9 +55,9 @@ class MockRequest(BaseOpenAPIRequest): self.parameters = { 'path': view_args or {}, - 'query': args or {}, - 'headers': headers or {}, - 'cookies': cookies or {}, + 'query': ImmutableMultiDict(args or []), + 'header': headers or {}, + 'cookie': cookies or {}, } self.body = data or '' diff --git a/tests/integration/data/v3.0/petstore.yaml b/tests/integration/data/v3.0/petstore.yaml index abc7d5b..6dbdf89 100644 --- a/tests/integration/data/v3.0/petstore.yaml +++ b/tests/integration/data/v3.0/petstore.yaml @@ -49,6 +49,14 @@ paths: items: type: integer format: int32 + - name: tags + in: query + description: Filter pets with tags + schema: + type: array + items: + $ref: "#/components/schemas/Tag" + explode: false responses: '200': description: An paged array of pets @@ -119,9 +127,9 @@ components: Tag: type: string enum: - - Cat - - Dog - - Bird + - cats + - dogs + - birds Position: type: integer enum: @@ -148,7 +156,7 @@ components: name: type: string tag: - type: "#/components/schemas/Tag" + $ref: "#/components/schemas/Tag" address: $ref: "#/components/schemas/Address" position: diff --git a/tests/integration/test_petstore.py b/tests/integration/test_petstore.py index 4da3ba1..e53c038 100644 --- a/tests/integration/test_petstore.py +++ b/tests/integration/test_petstore.py @@ -128,7 +128,10 @@ class TestPetstore(object): continue assert type(parameter.schema) == Schema - assert parameter.schema.type == schema_spec['type'] + assert parameter.schema.type.value ==\ + schema_spec['type'] + assert parameter.schema.format.value ==\ + schema_spec.get('format') assert parameter.schema.required == schema_spec.get( 'required', []) @@ -160,7 +163,10 @@ class TestPetstore(object): continue assert type(media_type.schema) == Schema - assert media_type.schema.type == schema_spec['type'] + assert media_type.schema.type.value ==\ + schema_spec['type'] + assert media_type.schema.format.value ==\ + schema_spec.get('format') assert media_type.schema.required == schema_spec.get( 'required', False) @@ -171,6 +177,41 @@ class TestPetstore(object): assert type(schema) == Schema def test_get_pets(self, spec, response_validator): + host_url = 'http://petstore.swagger.io/v1' + path_pattern = '/v1/pets' + query_params = { + 'limit': '20', + } + + request = MockRequest( + host_url, 'GET', '/pets', + path_pattern=path_pattern, args=query_params, + ) + + parameters = request.get_parameters(spec) + body = request.get_body(spec) + + assert parameters == { + 'query': { + 'limit': 20, + 'page': 1, + 'search': '', + } + } + assert body is None + + data_json = { + 'data': [], + } + data = json.dumps(data_json) + response = MockResponse(data) + + response_result = response_validator.validate(request, response) + + assert response_result.errors == [] + assert response_result.data == data_json + + def test_get_pets_ids_param(self, spec, response_validator): host_url = 'http://petstore.swagger.io/v1' path_pattern = '/v1/pets' query_params = { @@ -207,6 +248,43 @@ class TestPetstore(object): assert response_result.errors == [] assert response_result.data == data_json + def test_get_pets_tags_param(self, spec, response_validator): + host_url = 'http://petstore.swagger.io/v1' + path_pattern = '/v1/pets' + query_params = [ + ('limit', '20'), + ('tags', 'cats,dogs'), + ] + + request = MockRequest( + host_url, 'GET', '/pets', + path_pattern=path_pattern, args=query_params, + ) + + parameters = request.get_parameters(spec) + body = request.get_body(spec) + + assert parameters == { + 'query': { + 'limit': 20, + 'page': 1, + 'search': '', + 'tags': ['cats', 'dogs'], + } + } + assert body is None + + data_json = { + 'data': [], + } + data = json.dumps(data_json) + response = MockResponse(data) + + response_result = response_validator.validate(request, response) + + assert response_result.errors == [] + assert response_result.data == data_json + def test_get_pets_wrong_parameter_type(self, spec): host_url = 'http://petstore.swagger.io/v1' path_pattern = '/v1/pets' diff --git a/tests/integration/test_wrappers.py b/tests/integration/test_wrappers.py index 0a4ff01..b33505e 100644 --- a/tests/integration/test_wrappers.py +++ b/tests/integration/test_wrappers.py @@ -13,8 +13,8 @@ class TestFlaskOpenAPIRequest(object): server_name = 'localhost' @pytest.fixture - def environ(self): - return create_environ() + def environ_factory(self): + return create_environ @pytest.fixture def map(self): @@ -33,8 +33,9 @@ class TestFlaskOpenAPIRequest(object): ], default_subdomain='www') @pytest.fixture - def request_factory(self, map, environ): - def create_request(method, path, subdomain=None): + def request_factory(self, map, environ_factory): + def create_request(method, path, subdomain=None, query_string=None): + environ = environ_factory(query_string=query_string) req = Request(environ) urls = map.bind_to_environ( environ, server_name=self.server_name, subdomain=subdomain) @@ -47,14 +48,14 @@ class TestFlaskOpenAPIRequest(object): def openapi_request(self, request): return FlaskOpenAPIRequest(request) - def test_simple(self, request_factory, environ, request): + def test_simple(self, request_factory, request): request = request_factory('GET', '/', subdomain='www') openapi_request = FlaskOpenAPIRequest(request) path = {} query = ImmutableMultiDict([]) - headers = EnvironHeaders(environ) + headers = EnvironHeaders(request.environ) cookies = {} assert openapi_request.parameters == { 'path': path, @@ -69,14 +70,39 @@ class TestFlaskOpenAPIRequest(object): assert openapi_request.body == request.data assert openapi_request.mimetype == request.mimetype - def test_url_rule(self, request_factory, environ, request): + def test_multiple_values(self, request_factory, request): + request = request_factory( + 'GET', '/', subdomain='www', query_string='a=b&a=c') + + openapi_request = FlaskOpenAPIRequest(request) + + path = {} + query = ImmutableMultiDict([ + ('a', 'b'), ('a', 'c'), + ]) + headers = EnvironHeaders(request.environ) + cookies = {} + assert openapi_request.parameters == { + 'path': path, + 'query': query, + 'headers': headers, + 'cookies': cookies, + } + assert openapi_request.host_url == request.host_url + assert openapi_request.path == request.path + assert openapi_request.method == request.method.lower() + assert openapi_request.path_pattern == request.path + assert openapi_request.body == request.data + assert openapi_request.mimetype == request.mimetype + + def test_url_rule(self, request_factory, request): request = request_factory('GET', '/browse/12/', subdomain='kb') openapi_request = FlaskOpenAPIRequest(request) path = {'id': 12} query = ImmutableMultiDict([]) - headers = EnvironHeaders(environ) + headers = EnvironHeaders(request.environ) cookies = {} assert openapi_request.parameters == { 'path': path, diff --git a/tests/unit/test_paramters.py b/tests/unit/test_paramters.py index b3777d9..f4d3f01 100644 --- a/tests/unit/test_paramters.py +++ b/tests/unit/test_paramters.py @@ -1,9 +1,41 @@ import pytest +from openapi_core.enums import ParameterStyle from openapi_core.exceptions import EmptyValue from openapi_core.parameters import Parameter +class TestParameterInit(object): + + def test_path(self): + param = Parameter('param', 'path') + + assert param.allow_empty_value is False + assert param.style == ParameterStyle.SIMPLE + assert param.explode is False + + def test_query(self): + param = Parameter('param', 'query') + + assert param.allow_empty_value is False + assert param.style == ParameterStyle.FORM + assert param.explode is True + + def test_header(self): + param = Parameter('param', 'header') + + assert param.allow_empty_value is False + assert param.style == ParameterStyle.SIMPLE + assert param.explode is False + + def test_cookie(self): + param = Parameter('param', 'cookie') + + assert param.allow_empty_value is False + assert param.style == ParameterStyle.FORM + assert param.explode is True + + class TestParameterUnmarshal(object): def test_deprecated(self):