Merge pull request #12 from p1c2u/feature/serialized-parameters-support

Serialized parameters support
This commit is contained in:
A 2017-11-14 16:12:32 +00:00 committed by GitHub
commit 64a5045fd8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 323 additions and 76 deletions

48
openapi_core/enums.py Normal file
View file

@ -0,0 +1,48 @@
from enum import Enum
class ParameterLocation(Enum):
PATH = 'path'
QUERY = 'query'
HEADER = 'header'
COOKIE = 'cookie'
@classmethod
def has_value(cls, value):
return (any(value == item.value for item in cls))
class ParameterStyle(Enum):
MATRIX = 'matrix'
LABEL = 'label'
FORM = 'form'
SIMPLE = 'simple'
SPACE_DELIMITED = 'spaceDelimited'
PIPE_DELIMITED = 'pipeDelimited'
DEEP_OBJECT = 'deepObject'
class SchemaType(Enum):
INTEGER = 'integer'
NUMBER = 'number'
STRING = 'string'
BOOLEAN = 'boolean'
ARRAY = 'array'
OBJECT = 'object'
class SchemaFormat(Enum):
NONE = None
INT32 = 'int32'
INT64 = 'int64'
FLOAT = 'float'
DOUBLE = 'double'
BYTE = 'byte'
BINARY = 'binary'
DATE = 'date'
DATETIME = 'date-time'
PASSWORD = 'password'

View file

@ -2,14 +2,23 @@
import logging import logging
import warnings import warnings
from functools import lru_cache
from six import iteritems from six import iteritems
from openapi_core.enums import ParameterLocation, ParameterStyle, SchemaType
from openapi_core.exceptions import ( from openapi_core.exceptions import (
EmptyValue, InvalidValueType, InvalidParameterValue, EmptyValue, InvalidValueType, InvalidParameterValue,
) )
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
PARAMETER_STYLE_DESERIALIZERS = {
ParameterStyle.FORM: lambda x: x.split(','),
ParameterStyle.SIMPLE: lambda x: x.split(','),
ParameterStyle.SPACE_DELIMITED: lambda x: x.split(' '),
ParameterStyle.PIPE_DELIMITED: lambda x: x.split('|'),
}
class Parameter(object): class Parameter(object):
"""Represents an OpenAPI operation Parameter.""" """Represents an OpenAPI operation Parameter."""
@ -17,17 +26,49 @@ class Parameter(object):
def __init__( def __init__(
self, name, location, schema=None, required=False, self, name, location, schema=None, required=False,
deprecated=False, allow_empty_value=False, deprecated=False, allow_empty_value=False,
items=None, collection_format=None): items=None, style=None, explode=None):
self.name = name self.name = name
self.location = location self.location = ParameterLocation(location)
self.schema = schema self.schema = schema
self.required = True if self.location == "path" else required self.required = (
True if self.location == ParameterLocation.PATH else required
)
self.deprecated = deprecated self.deprecated = deprecated
self.allow_empty_value = ( self.allow_empty_value = (
allow_empty_value if self.location == "query" else False allow_empty_value if self.location == ParameterLocation.QUERY
else False
) )
self.items = items self.items = items
self.collection_format = collection_format self.style = ParameterStyle(style or self.default_style)
self.explode = self.default_explode if explode is None else explode
@property
def aslist(self):
return (
self.schema and
self.schema.type in [SchemaType.ARRAY, SchemaType.OBJECT]
)
@property
def default_style(self):
simple_locations = [ParameterLocation.PATH, ParameterLocation.HEADER]
return (
'simple' if self.location in simple_locations else "form"
)
@property
def default_explode(self):
return self.style == ParameterStyle.FORM
def get_dererializer(self):
return PARAMETER_STYLE_DESERIALIZERS[self.style]
def deserialize(self, value):
if not self.aslist or self.explode:
return value
deserializer = self.get_dererializer()
return deserializer(value)
def unmarshal(self, value): def unmarshal(self, value):
if self.deprecated: if self.deprecated:
@ -36,7 +77,7 @@ class Parameter(object):
DeprecationWarning, DeprecationWarning,
) )
if (self.location == "query" and value == "" and if (self.location == ParameterLocation.QUERY and value == "" and
not self.allow_empty_value): not self.allow_empty_value):
raise EmptyValue( raise EmptyValue(
"Value of {0} parameter cannot be empty".format(self.name)) "Value of {0} parameter cannot be empty".format(self.name))
@ -44,12 +85,45 @@ class Parameter(object):
if not self.schema: if not self.schema:
return value return value
deserialized = self.deserialize(value)
try: try:
return self.schema.unmarshal(value) return self.schema.unmarshal(deserialized)
except InvalidValueType as exc: except InvalidValueType as exc:
raise InvalidParameterValue(str(exc)) raise InvalidParameterValue(str(exc))
class ParameterFactory(object):
def __init__(self, dereferencer, schemas_registry):
self.dereferencer = dereferencer
self.schemas_registry = schemas_registry
def create(self, parameter_spec, parameter_name=None):
parameter_deref = self.dereferencer.dereference(parameter_spec)
parameter_name = parameter_name or parameter_deref['name']
parameter_in = parameter_deref.get('in', 'header')
allow_empty_value = parameter_deref.get('allowEmptyValue')
required = parameter_deref.get('required', False)
style = parameter_deref.get('style')
explode = parameter_deref.get('explode')
schema_spec = parameter_deref.get('schema', None)
schema = None
if schema_spec:
schema, _ = self.schemas_registry.get_or_create(schema_spec)
return Parameter(
parameter_name, parameter_in,
schema=schema, required=required,
allow_empty_value=allow_empty_value,
style=style, explode=explode,
)
class ParametersGenerator(object): class ParametersGenerator(object):
def __init__(self, dereferencer, schemas_registry): def __init__(self, dereferencer, schemas_registry):
@ -57,48 +131,19 @@ class ParametersGenerator(object):
self.schemas_registry = schemas_registry self.schemas_registry = schemas_registry
def generate(self, parameters): def generate(self, parameters):
for parameter_name, parameter in iteritems(parameters): for parameter_name, parameter_spec in iteritems(parameters):
parameter_deref = self.dereferencer.dereference(parameter) parameter = self.parameter_factory.create(
parameter_spec, parameter_name=parameter_name)
parameter_in = parameter_deref.get('in', 'header') yield (parameter_name, parameter)
allow_empty_value = parameter_deref.get('allowEmptyValue')
required = parameter_deref.get('required', False)
schema_spec = parameter_deref.get('schema', None)
schema = None
if schema_spec:
schema, _ = self.schemas_registry.get_or_create(schema_spec)
yield (
parameter_name,
Parameter(
parameter_name, parameter_in,
schema=schema, required=required,
allow_empty_value=allow_empty_value,
),
)
def generate_from_list(self, parameters_list): def generate_from_list(self, parameters_list):
for parameter in parameters_list: for parameter_spec in parameters_list:
parameter_deref = self.dereferencer.dereference(parameter) parameter = self.parameter_factory.create(parameter_spec)
parameter_name = parameter_deref['name'] yield (parameter.name, parameter)
parameter_in = parameter_deref.get('in', 'header')
allow_empty_value = parameter_deref.get('allowEmptyValue') @property
required = parameter_deref.get('required', False) @lru_cache()
def parameter_factory(self):
schema_spec = parameter_deref.get('schema', None) return ParameterFactory(self.dereferencer, self.schemas_registry)
schema = None
if schema_spec:
schema, _ = self.schemas_registry.get_or_create(schema_spec)
yield (
parameter_name,
Parameter(
parameter_name, parameter_in,
schema=schema, required=required,
allow_empty_value=allow_empty_value,
),
)

View file

@ -9,6 +9,7 @@ from functools import lru_cache
from json import loads from json import loads
from six import iteritems from six import iteritems
from openapi_core.enums import SchemaType, SchemaFormat
from openapi_core.exceptions import ( from openapi_core.exceptions import (
InvalidValueType, UndefinedSchemaProperty, MissingProperty, InvalidValue, InvalidValueType, UndefinedSchemaProperty, MissingProperty, InvalidValue,
) )
@ -17,9 +18,9 @@ from openapi_core.models import ModelFactory
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
DEFAULT_CAST_CALLABLE_GETTER = { DEFAULT_CAST_CALLABLE_GETTER = {
'integer': int, SchemaType.INTEGER: int,
'number': float, SchemaType.NUMBER: float,
'boolean': lambda x: bool(strtobool(x)), SchemaType.BOOLEAN: lambda x: bool(strtobool(x)),
} }
@ -28,13 +29,13 @@ class Schema(object):
def __init__( def __init__(
self, schema_type, model=None, properties=None, items=None, self, schema_type, model=None, properties=None, items=None,
spec_format=None, required=None, default=None, nullable=False, schema_format=None, required=None, default=None, nullable=False,
enum=None, deprecated=False, all_of=None): enum=None, deprecated=False, all_of=None):
self.type = schema_type self.type = SchemaType(schema_type)
self.model = model self.model = model
self.properties = properties and dict(properties) or {} self.properties = properties and dict(properties) or {}
self.items = items self.items = items
self.format = spec_format self.format = SchemaFormat(schema_format)
self.required = required or [] self.required = required or []
self.default = default self.default = default
self.nullable = nullable self.nullable = nullable
@ -57,8 +58,8 @@ class Schema(object):
def get_cast_mapping(self): def get_cast_mapping(self):
mapping = DEFAULT_CAST_CALLABLE_GETTER.copy() mapping = DEFAULT_CAST_CALLABLE_GETTER.copy()
mapping.update({ mapping.update({
'array': self._unmarshal_collection, SchemaType.ARRAY: self._unmarshal_collection,
'object': self._unmarshal_object, SchemaType.OBJECT: self._unmarshal_object,
}) })
return defaultdict(lambda: lambda x: x, mapping) return defaultdict(lambda: lambda x: x, mapping)
@ -159,6 +160,7 @@ class SchemaFactory(object):
schema_deref = self.dereferencer.dereference(schema_spec) schema_deref = self.dereferencer.dereference(schema_spec)
schema_type = schema_deref['type'] schema_type = schema_deref['type']
schema_format = schema_deref.get('format')
model = schema_deref.get('x-model', None) model = schema_deref.get('x-model', None)
required = schema_deref.get('required', False) required = schema_deref.get('required', False)
default = schema_deref.get('default', None) default = schema_deref.get('default', None)
@ -183,8 +185,8 @@ class SchemaFactory(object):
return Schema( return Schema(
schema_type, model=model, properties=properties, items=items, schema_type, model=model, properties=properties, items=items,
required=required, default=default, nullable=nullable, enum=enum, schema_format=schema_format, required=required, default=default,
deprecated=deprecated, all_of=all_of, nullable=nullable, enum=enum, deprecated=deprecated, all_of=all_of,
) )
@property @property

View file

@ -95,7 +95,7 @@ class RequestValidator(object):
except OpenAPIMappingError as exc: except OpenAPIMappingError as exc:
errors.append(exc) errors.append(exc)
else: else:
parameters[param.location][param_name] = value parameters[param.location.value][param_name] = value
if operation.request_body is not None: if operation.request_body is not None:
try: try:
@ -117,12 +117,19 @@ class RequestValidator(object):
return RequestValidationResult(errors, body, parameters) return RequestValidationResult(errors, body, parameters)
def _get_raw_value(self, request, param): def _get_raw_value(self, request, param):
location = request.parameters[param.location.value]
try: try:
return request.parameters[param.location][param.name] raw = request.parameters[param.location.value][param.name]
except KeyError: except KeyError:
raise MissingParameter( raise MissingParameter(
"Missing required `{0}` parameter".format(param.name)) "Missing required `{0}` parameter".format(param.name))
if param.aslist and param.explode:
return location.getlist(param.name)
return raw
def _get_raw_body(self, request): def _get_raw_body(self, request):
if not request.body: if not request.body:
raise MissingBody("Missing required request body") raise MissingBody("Missing required request body")

View file

@ -2,6 +2,7 @@
import warnings import warnings
from six.moves.urllib.parse import urljoin from six.moves.urllib.parse import urljoin
from werkzeug.datastructures import ImmutableMultiDict
class BaseOpenAPIRequest(object): class BaseOpenAPIRequest(object):
@ -54,9 +55,9 @@ class MockRequest(BaseOpenAPIRequest):
self.parameters = { self.parameters = {
'path': view_args or {}, 'path': view_args or {},
'query': args or {}, 'query': ImmutableMultiDict(args or []),
'headers': headers or {}, 'header': headers or {},
'cookies': cookies or {}, 'cookie': cookies or {},
} }
self.body = data or '' self.body = data or ''

View file

@ -49,6 +49,14 @@ paths:
items: items:
type: integer type: integer
format: int32 format: int32
- name: tags
in: query
description: Filter pets with tags
schema:
type: array
items:
$ref: "#/components/schemas/Tag"
explode: false
responses: responses:
'200': '200':
description: An paged array of pets description: An paged array of pets
@ -119,9 +127,9 @@ components:
Tag: Tag:
type: string type: string
enum: enum:
- Cat - cats
- Dog - dogs
- Bird - birds
Position: Position:
type: integer type: integer
enum: enum:
@ -148,7 +156,7 @@ components:
name: name:
type: string type: string
tag: tag:
type: "#/components/schemas/Tag" $ref: "#/components/schemas/Tag"
address: address:
$ref: "#/components/schemas/Address" $ref: "#/components/schemas/Address"
position: position:

View file

@ -128,7 +128,10 @@ class TestPetstore(object):
continue continue
assert type(parameter.schema) == Schema assert type(parameter.schema) == Schema
assert parameter.schema.type == schema_spec['type'] assert parameter.schema.type.value ==\
schema_spec['type']
assert parameter.schema.format.value ==\
schema_spec.get('format')
assert parameter.schema.required == schema_spec.get( assert parameter.schema.required == schema_spec.get(
'required', []) 'required', [])
@ -160,7 +163,10 @@ class TestPetstore(object):
continue continue
assert type(media_type.schema) == Schema assert type(media_type.schema) == Schema
assert media_type.schema.type == schema_spec['type'] assert media_type.schema.type.value ==\
schema_spec['type']
assert media_type.schema.format.value ==\
schema_spec.get('format')
assert media_type.schema.required == schema_spec.get( assert media_type.schema.required == schema_spec.get(
'required', False) 'required', False)
@ -171,6 +177,41 @@ class TestPetstore(object):
assert type(schema) == Schema assert type(schema) == Schema
def test_get_pets(self, spec, response_validator): def test_get_pets(self, spec, response_validator):
host_url = 'http://petstore.swagger.io/v1'
path_pattern = '/v1/pets'
query_params = {
'limit': '20',
}
request = MockRequest(
host_url, 'GET', '/pets',
path_pattern=path_pattern, args=query_params,
)
parameters = request.get_parameters(spec)
body = request.get_body(spec)
assert parameters == {
'query': {
'limit': 20,
'page': 1,
'search': '',
}
}
assert body is None
data_json = {
'data': [],
}
data = json.dumps(data_json)
response = MockResponse(data)
response_result = response_validator.validate(request, response)
assert response_result.errors == []
assert response_result.data == data_json
def test_get_pets_ids_param(self, spec, response_validator):
host_url = 'http://petstore.swagger.io/v1' host_url = 'http://petstore.swagger.io/v1'
path_pattern = '/v1/pets' path_pattern = '/v1/pets'
query_params = { query_params = {
@ -207,6 +248,43 @@ class TestPetstore(object):
assert response_result.errors == [] assert response_result.errors == []
assert response_result.data == data_json assert response_result.data == data_json
def test_get_pets_tags_param(self, spec, response_validator):
host_url = 'http://petstore.swagger.io/v1'
path_pattern = '/v1/pets'
query_params = [
('limit', '20'),
('tags', 'cats,dogs'),
]
request = MockRequest(
host_url, 'GET', '/pets',
path_pattern=path_pattern, args=query_params,
)
parameters = request.get_parameters(spec)
body = request.get_body(spec)
assert parameters == {
'query': {
'limit': 20,
'page': 1,
'search': '',
'tags': ['cats', 'dogs'],
}
}
assert body is None
data_json = {
'data': [],
}
data = json.dumps(data_json)
response = MockResponse(data)
response_result = response_validator.validate(request, response)
assert response_result.errors == []
assert response_result.data == data_json
def test_get_pets_wrong_parameter_type(self, spec): def test_get_pets_wrong_parameter_type(self, spec):
host_url = 'http://petstore.swagger.io/v1' host_url = 'http://petstore.swagger.io/v1'
path_pattern = '/v1/pets' path_pattern = '/v1/pets'

View file

@ -13,8 +13,8 @@ class TestFlaskOpenAPIRequest(object):
server_name = 'localhost' server_name = 'localhost'
@pytest.fixture @pytest.fixture
def environ(self): def environ_factory(self):
return create_environ() return create_environ
@pytest.fixture @pytest.fixture
def map(self): def map(self):
@ -33,8 +33,9 @@ class TestFlaskOpenAPIRequest(object):
], default_subdomain='www') ], default_subdomain='www')
@pytest.fixture @pytest.fixture
def request_factory(self, map, environ): def request_factory(self, map, environ_factory):
def create_request(method, path, subdomain=None): def create_request(method, path, subdomain=None, query_string=None):
environ = environ_factory(query_string=query_string)
req = Request(environ) req = Request(environ)
urls = map.bind_to_environ( urls = map.bind_to_environ(
environ, server_name=self.server_name, subdomain=subdomain) environ, server_name=self.server_name, subdomain=subdomain)
@ -47,14 +48,14 @@ class TestFlaskOpenAPIRequest(object):
def openapi_request(self, request): def openapi_request(self, request):
return FlaskOpenAPIRequest(request) return FlaskOpenAPIRequest(request)
def test_simple(self, request_factory, environ, request): def test_simple(self, request_factory, request):
request = request_factory('GET', '/', subdomain='www') request = request_factory('GET', '/', subdomain='www')
openapi_request = FlaskOpenAPIRequest(request) openapi_request = FlaskOpenAPIRequest(request)
path = {} path = {}
query = ImmutableMultiDict([]) query = ImmutableMultiDict([])
headers = EnvironHeaders(environ) headers = EnvironHeaders(request.environ)
cookies = {} cookies = {}
assert openapi_request.parameters == { assert openapi_request.parameters == {
'path': path, 'path': path,
@ -69,14 +70,39 @@ class TestFlaskOpenAPIRequest(object):
assert openapi_request.body == request.data assert openapi_request.body == request.data
assert openapi_request.mimetype == request.mimetype assert openapi_request.mimetype == request.mimetype
def test_url_rule(self, request_factory, environ, request): def test_multiple_values(self, request_factory, request):
request = request_factory(
'GET', '/', subdomain='www', query_string='a=b&a=c')
openapi_request = FlaskOpenAPIRequest(request)
path = {}
query = ImmutableMultiDict([
('a', 'b'), ('a', 'c'),
])
headers = EnvironHeaders(request.environ)
cookies = {}
assert openapi_request.parameters == {
'path': path,
'query': query,
'headers': headers,
'cookies': cookies,
}
assert openapi_request.host_url == request.host_url
assert openapi_request.path == request.path
assert openapi_request.method == request.method.lower()
assert openapi_request.path_pattern == request.path
assert openapi_request.body == request.data
assert openapi_request.mimetype == request.mimetype
def test_url_rule(self, request_factory, request):
request = request_factory('GET', '/browse/12/', subdomain='kb') request = request_factory('GET', '/browse/12/', subdomain='kb')
openapi_request = FlaskOpenAPIRequest(request) openapi_request = FlaskOpenAPIRequest(request)
path = {'id': 12} path = {'id': 12}
query = ImmutableMultiDict([]) query = ImmutableMultiDict([])
headers = EnvironHeaders(environ) headers = EnvironHeaders(request.environ)
cookies = {} cookies = {}
assert openapi_request.parameters == { assert openapi_request.parameters == {
'path': path, 'path': path,

View file

@ -1,9 +1,41 @@
import pytest import pytest
from openapi_core.enums import ParameterStyle
from openapi_core.exceptions import EmptyValue from openapi_core.exceptions import EmptyValue
from openapi_core.parameters import Parameter from openapi_core.parameters import Parameter
class TestParameterInit(object):
def test_path(self):
param = Parameter('param', 'path')
assert param.allow_empty_value is False
assert param.style == ParameterStyle.SIMPLE
assert param.explode is False
def test_query(self):
param = Parameter('param', 'query')
assert param.allow_empty_value is False
assert param.style == ParameterStyle.FORM
assert param.explode is True
def test_header(self):
param = Parameter('param', 'header')
assert param.allow_empty_value is False
assert param.style == ParameterStyle.SIMPLE
assert param.explode is False
def test_cookie(self):
param = Parameter('param', 'cookie')
assert param.allow_empty_value is False
assert param.style == ParameterStyle.FORM
assert param.explode is True
class TestParameterUnmarshal(object): class TestParameterUnmarshal(object):
def test_deprecated(self): def test_deprecated(self):