Parameter deserializers

This commit is contained in:
Artur Maciag 2017-11-14 16:05:03 +00:00
parent 2d50e920ab
commit fd450e6be6
6 changed files with 170 additions and 67 deletions

View file

@ -2,15 +2,23 @@
import logging import logging
import warnings import warnings
from functools import lru_cache
from six import iteritems from six import iteritems
from openapi_core.enums import ParameterLocation, ParameterStyle from openapi_core.enums import ParameterLocation, ParameterStyle, SchemaType
from openapi_core.exceptions import ( from openapi_core.exceptions import (
EmptyValue, InvalidValueType, InvalidParameterValue, EmptyValue, InvalidValueType, InvalidParameterValue,
) )
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
PARAMETER_STYLE_DESERIALIZERS = {
ParameterStyle.FORM: lambda x: x.split(','),
ParameterStyle.SIMPLE: lambda x: x.split(','),
ParameterStyle.SPACE_DELIMITED: lambda x: x.split(' '),
ParameterStyle.PIPE_DELIMITED: lambda x: x.split('|'),
}
class Parameter(object): class Parameter(object):
"""Represents an OpenAPI operation Parameter.""" """Represents an OpenAPI operation Parameter."""
@ -32,7 +40,14 @@ class Parameter(object):
) )
self.items = items self.items = items
self.style = ParameterStyle(style or self.default_style) self.style = ParameterStyle(style or self.default_style)
self.explode = explode or self.default_explode self.explode = self.default_explode if explode is None else explode
@property
def aslist(self):
return (
self.schema and
self.schema.type in [SchemaType.ARRAY, SchemaType.OBJECT]
)
@property @property
def default_style(self): def default_style(self):
@ -45,6 +60,16 @@ class Parameter(object):
def default_explode(self): def default_explode(self):
return self.style == ParameterStyle.FORM return self.style == ParameterStyle.FORM
def get_dererializer(self):
return PARAMETER_STYLE_DESERIALIZERS[self.style]
def deserialize(self, value):
if not self.aslist or self.explode:
return value
deserializer = self.get_dererializer()
return deserializer(value)
def unmarshal(self, value): def unmarshal(self, value):
if self.deprecated: if self.deprecated:
warnings.warn( warnings.warn(
@ -60,12 +85,45 @@ class Parameter(object):
if not self.schema: if not self.schema:
return value return value
deserialized = self.deserialize(value)
try: try:
return self.schema.unmarshal(value) return self.schema.unmarshal(deserialized)
except InvalidValueType as exc: except InvalidValueType as exc:
raise InvalidParameterValue(str(exc)) raise InvalidParameterValue(str(exc))
class ParameterFactory(object):
def __init__(self, dereferencer, schemas_registry):
self.dereferencer = dereferencer
self.schemas_registry = schemas_registry
def create(self, parameter_spec, parameter_name=None):
parameter_deref = self.dereferencer.dereference(parameter_spec)
parameter_name = parameter_name or parameter_deref['name']
parameter_in = parameter_deref.get('in', 'header')
allow_empty_value = parameter_deref.get('allowEmptyValue')
required = parameter_deref.get('required', False)
style = parameter_deref.get('style')
explode = parameter_deref.get('explode')
schema_spec = parameter_deref.get('schema', None)
schema = None
if schema_spec:
schema, _ = self.schemas_registry.get_or_create(schema_spec)
return Parameter(
parameter_name, parameter_in,
schema=schema, required=required,
allow_empty_value=allow_empty_value,
style=style, explode=explode,
)
class ParametersGenerator(object): class ParametersGenerator(object):
def __init__(self, dereferencer, schemas_registry): def __init__(self, dereferencer, schemas_registry):
@ -73,52 +131,19 @@ class ParametersGenerator(object):
self.schemas_registry = schemas_registry self.schemas_registry = schemas_registry
def generate(self, parameters): def generate(self, parameters):
for parameter_name, parameter in iteritems(parameters): for parameter_name, parameter_spec in iteritems(parameters):
parameter_deref = self.dereferencer.dereference(parameter) parameter = self.parameter_factory.create(
parameter_spec, parameter_name=parameter_name)
parameter_in = parameter_deref.get('in', 'header') yield (parameter_name, parameter)
allow_empty_value = parameter_deref.get('allowEmptyValue')
required = parameter_deref.get('required', False)
schema_spec = parameter_deref.get('schema', None)
schema = None
if schema_spec:
schema, _ = self.schemas_registry.get_or_create(schema_spec)
yield (
parameter_name,
Parameter(
parameter_name, parameter_in,
schema=schema, required=required,
allow_empty_value=allow_empty_value,
),
)
def generate_from_list(self, parameters_list): def generate_from_list(self, parameters_list):
for parameter in parameters_list: for parameter_spec in parameters_list:
parameter_deref = self.dereferencer.dereference(parameter) parameter = self.parameter_factory.create(parameter_spec)
parameter_name = parameter_deref['name'] yield (parameter.name, parameter)
parameter_in = parameter_deref.get('in', 'header')
allow_empty_value = parameter_deref.get('allowEmptyValue') @property
required = parameter_deref.get('required', False) @lru_cache()
def parameter_factory(self):
style = parameter_deref.get('style') return ParameterFactory(self.dereferencer, self.schemas_registry)
explode = parameter_deref.get('explode')
schema_spec = parameter_deref.get('schema', None)
schema = None
if schema_spec:
schema, _ = self.schemas_registry.get_or_create(schema_spec)
yield (
parameter_name,
Parameter(
parameter_name, parameter_in,
schema=schema, required=required,
allow_empty_value=allow_empty_value,
style=style, explode=explode,
),
)

View file

@ -1,7 +1,6 @@
"""OpenAPI core validators module""" """OpenAPI core validators module"""
from six import iteritems from six import iteritems
from openapi_core.enums import ParameterLocation
from openapi_core.exceptions import ( from openapi_core.exceptions import (
OpenAPIMappingError, MissingParameter, MissingBody, InvalidResponse, OpenAPIMappingError, MissingParameter, MissingBody, InvalidResponse,
) )
@ -9,6 +8,8 @@ from openapi_core.exceptions import (
class RequestParameters(dict): class RequestParameters(dict):
valid_locations = ['path', 'query', 'headers', 'cookies']
def __getitem__(self, location): def __getitem__(self, location):
self.validate_location(location) self.validate_location(location)
@ -19,7 +20,7 @@ class RequestParameters(dict):
@classmethod @classmethod
def validate_location(cls, location): def validate_location(cls, location):
if not ParameterLocation.has_value(location): if location not in cls.valid_locations:
raise OpenAPIMappingError( raise OpenAPIMappingError(
"Unknown parameter location: {0}".format(str(location))) "Unknown parameter location: {0}".format(str(location)))
@ -116,12 +117,19 @@ class RequestValidator(object):
return RequestValidationResult(errors, body, parameters) return RequestValidationResult(errors, body, parameters)
def _get_raw_value(self, request, param): def _get_raw_value(self, request, param):
location = request.parameters[param.location.value]
try: try:
return request.parameters[param.location.value][param.name] raw = request.parameters[param.location.value][param.name]
except KeyError: except KeyError:
raise MissingParameter( raise MissingParameter(
"Missing required `{0}` parameter".format(param.name)) "Missing required `{0}` parameter".format(param.name))
if param.aslist and param.explode:
return location.getlist(param.name)
return raw
def _get_raw_body(self, request): def _get_raw_body(self, request):
if not request.body: if not request.body:
raise MissingBody("Missing required request body") raise MissingBody("Missing required request body")

View file

@ -2,6 +2,7 @@
import warnings import warnings
from six.moves.urllib.parse import urljoin from six.moves.urllib.parse import urljoin
from werkzeug.datastructures import ImmutableMultiDict
class BaseOpenAPIRequest(object): class BaseOpenAPIRequest(object):
@ -54,9 +55,9 @@ class MockRequest(BaseOpenAPIRequest):
self.parameters = { self.parameters = {
'path': view_args or {}, 'path': view_args or {},
'query': args or {}, 'query': ImmutableMultiDict(args or []),
'headers': headers or {}, 'header': headers or {},
'cookies': cookies or {}, 'cookie': cookies or {},
} }
self.body = data or '' self.body = data or ''

View file

@ -49,6 +49,14 @@ paths:
items: items:
type: integer type: integer
format: int32 format: int32
- name: tags
in: query
description: Filter pets with tags
schema:
type: array
items:
$ref: "#/components/schemas/Tag"
explode: false
responses: responses:
'200': '200':
description: An paged array of pets description: An paged array of pets

View file

@ -177,6 +177,41 @@ class TestPetstore(object):
assert type(schema) == Schema assert type(schema) == Schema
def test_get_pets(self, spec, response_validator): def test_get_pets(self, spec, response_validator):
host_url = 'http://petstore.swagger.io/v1'
path_pattern = '/v1/pets'
query_params = {
'limit': '20',
}
request = MockRequest(
host_url, 'GET', '/pets',
path_pattern=path_pattern, args=query_params,
)
parameters = request.get_parameters(spec)
body = request.get_body(spec)
assert parameters == {
'query': {
'limit': 20,
'page': 1,
'search': '',
}
}
assert body is None
data_json = {
'data': [],
}
data = json.dumps(data_json)
response = MockResponse(data)
response_result = response_validator.validate(request, response)
assert response_result.errors == []
assert response_result.data == data_json
def test_get_pets_ids_param(self, spec, response_validator):
host_url = 'http://petstore.swagger.io/v1' host_url = 'http://petstore.swagger.io/v1'
path_pattern = '/v1/pets' path_pattern = '/v1/pets'
query_params = { query_params = {
@ -213,13 +248,13 @@ class TestPetstore(object):
assert response_result.errors == [] assert response_result.errors == []
assert response_result.data == data_json assert response_result.data == data_json
def test_get_pets_tag_param(self, spec, response_validator): def test_get_pets_tags_param(self, spec, response_validator):
host_url = 'http://petstore.swagger.io/v1' host_url = 'http://petstore.swagger.io/v1'
path_pattern = '/v1/pets' path_pattern = '/v1/pets'
query_params = { query_params = [
'limit': '20', ('limit', '20'),
'ids': ['12', '13'], ('tags', 'cats,dogs'),
} ]
request = MockRequest( request = MockRequest(
host_url, 'GET', '/pets', host_url, 'GET', '/pets',
@ -234,7 +269,7 @@ class TestPetstore(object):
'limit': 20, 'limit': 20,
'page': 1, 'page': 1,
'search': '', 'search': '',
'ids': [12, 13], 'tags': ['cats', 'dogs'],
} }
} }
assert body is None assert body is None

View file

@ -13,8 +13,8 @@ class TestFlaskOpenAPIRequest(object):
server_name = 'localhost' server_name = 'localhost'
@pytest.fixture @pytest.fixture
def environ(self): def environ_factory(self):
return create_environ() return create_environ
@pytest.fixture @pytest.fixture
def map(self): def map(self):
@ -33,8 +33,9 @@ class TestFlaskOpenAPIRequest(object):
], default_subdomain='www') ], default_subdomain='www')
@pytest.fixture @pytest.fixture
def request_factory(self, map, environ): def request_factory(self, map, environ_factory):
def create_request(method, path, subdomain=None): def create_request(method, path, subdomain=None, query_string=None):
environ = environ_factory(query_string=query_string)
req = Request(environ) req = Request(environ)
urls = map.bind_to_environ( urls = map.bind_to_environ(
environ, server_name=self.server_name, subdomain=subdomain) environ, server_name=self.server_name, subdomain=subdomain)
@ -47,14 +48,14 @@ class TestFlaskOpenAPIRequest(object):
def openapi_request(self, request): def openapi_request(self, request):
return FlaskOpenAPIRequest(request) return FlaskOpenAPIRequest(request)
def test_simple(self, request_factory, environ, request): def test_simple(self, request_factory, request):
request = request_factory('GET', '/', subdomain='www') request = request_factory('GET', '/', subdomain='www')
openapi_request = FlaskOpenAPIRequest(request) openapi_request = FlaskOpenAPIRequest(request)
path = {} path = {}
query = ImmutableMultiDict([]) query = ImmutableMultiDict([])
headers = EnvironHeaders(environ) headers = EnvironHeaders(request.environ)
cookies = {} cookies = {}
assert openapi_request.parameters == { assert openapi_request.parameters == {
'path': path, 'path': path,
@ -69,14 +70,39 @@ class TestFlaskOpenAPIRequest(object):
assert openapi_request.body == request.data assert openapi_request.body == request.data
assert openapi_request.mimetype == request.mimetype assert openapi_request.mimetype == request.mimetype
def test_url_rule(self, request_factory, environ, request): def test_multiple_values(self, request_factory, request):
request = request_factory(
'GET', '/', subdomain='www', query_string='a=b&a=c')
openapi_request = FlaskOpenAPIRequest(request)
path = {}
query = ImmutableMultiDict([
('a', 'b'), ('a', 'c'),
])
headers = EnvironHeaders(request.environ)
cookies = {}
assert openapi_request.parameters == {
'path': path,
'query': query,
'headers': headers,
'cookies': cookies,
}
assert openapi_request.host_url == request.host_url
assert openapi_request.path == request.path
assert openapi_request.method == request.method.lower()
assert openapi_request.path_pattern == request.path
assert openapi_request.body == request.data
assert openapi_request.mimetype == request.mimetype
def test_url_rule(self, request_factory, request):
request = request_factory('GET', '/browse/12/', subdomain='kb') request = request_factory('GET', '/browse/12/', subdomain='kb')
openapi_request = FlaskOpenAPIRequest(request) openapi_request = FlaskOpenAPIRequest(request)
path = {'id': 12} path = {'id': 12}
query = ImmutableMultiDict([]) query = ImmutableMultiDict([])
headers = EnvironHeaders(environ) headers = EnvironHeaders(request.environ)
cookies = {} cookies = {}
assert openapi_request.parameters == { assert openapi_request.parameters == {
'path': path, 'path': path,