Parameter deserializers

This commit is contained in:
Artur Maciag 2017-11-14 16:05:03 +00:00
parent 2d50e920ab
commit fd450e6be6
6 changed files with 170 additions and 67 deletions

View file

@ -2,15 +2,23 @@
import logging
import warnings
from functools import lru_cache
from six import iteritems
from openapi_core.enums import ParameterLocation, ParameterStyle
from openapi_core.enums import ParameterLocation, ParameterStyle, SchemaType
from openapi_core.exceptions import (
EmptyValue, InvalidValueType, InvalidParameterValue,
)
log = logging.getLogger(__name__)
PARAMETER_STYLE_DESERIALIZERS = {
ParameterStyle.FORM: lambda x: x.split(','),
ParameterStyle.SIMPLE: lambda x: x.split(','),
ParameterStyle.SPACE_DELIMITED: lambda x: x.split(' '),
ParameterStyle.PIPE_DELIMITED: lambda x: x.split('|'),
}
class Parameter(object):
"""Represents an OpenAPI operation Parameter."""
@ -32,7 +40,14 @@ class Parameter(object):
)
self.items = items
self.style = ParameterStyle(style or self.default_style)
self.explode = explode or self.default_explode
self.explode = self.default_explode if explode is None else explode
@property
def aslist(self):
return (
self.schema and
self.schema.type in [SchemaType.ARRAY, SchemaType.OBJECT]
)
@property
def default_style(self):
@ -45,6 +60,16 @@ class Parameter(object):
def default_explode(self):
return self.style == ParameterStyle.FORM
def get_dererializer(self):
return PARAMETER_STYLE_DESERIALIZERS[self.style]
def deserialize(self, value):
if not self.aslist or self.explode:
return value
deserializer = self.get_dererializer()
return deserializer(value)
def unmarshal(self, value):
if self.deprecated:
warnings.warn(
@ -60,12 +85,45 @@ class Parameter(object):
if not self.schema:
return value
deserialized = self.deserialize(value)
try:
return self.schema.unmarshal(value)
return self.schema.unmarshal(deserialized)
except InvalidValueType as exc:
raise InvalidParameterValue(str(exc))
class ParameterFactory(object):
def __init__(self, dereferencer, schemas_registry):
self.dereferencer = dereferencer
self.schemas_registry = schemas_registry
def create(self, parameter_spec, parameter_name=None):
parameter_deref = self.dereferencer.dereference(parameter_spec)
parameter_name = parameter_name or parameter_deref['name']
parameter_in = parameter_deref.get('in', 'header')
allow_empty_value = parameter_deref.get('allowEmptyValue')
required = parameter_deref.get('required', False)
style = parameter_deref.get('style')
explode = parameter_deref.get('explode')
schema_spec = parameter_deref.get('schema', None)
schema = None
if schema_spec:
schema, _ = self.schemas_registry.get_or_create(schema_spec)
return Parameter(
parameter_name, parameter_in,
schema=schema, required=required,
allow_empty_value=allow_empty_value,
style=style, explode=explode,
)
class ParametersGenerator(object):
def __init__(self, dereferencer, schemas_registry):
@ -73,52 +131,19 @@ class ParametersGenerator(object):
self.schemas_registry = schemas_registry
def generate(self, parameters):
for parameter_name, parameter in iteritems(parameters):
parameter_deref = self.dereferencer.dereference(parameter)
for parameter_name, parameter_spec in iteritems(parameters):
parameter = self.parameter_factory.create(
parameter_spec, parameter_name=parameter_name)
parameter_in = parameter_deref.get('in', 'header')
allow_empty_value = parameter_deref.get('allowEmptyValue')
required = parameter_deref.get('required', False)
schema_spec = parameter_deref.get('schema', None)
schema = None
if schema_spec:
schema, _ = self.schemas_registry.get_or_create(schema_spec)
yield (
parameter_name,
Parameter(
parameter_name, parameter_in,
schema=schema, required=required,
allow_empty_value=allow_empty_value,
),
)
yield (parameter_name, parameter)
def generate_from_list(self, parameters_list):
for parameter in parameters_list:
parameter_deref = self.dereferencer.dereference(parameter)
for parameter_spec in parameters_list:
parameter = self.parameter_factory.create(parameter_spec)
parameter_name = parameter_deref['name']
parameter_in = parameter_deref.get('in', 'header')
yield (parameter.name, parameter)
allow_empty_value = parameter_deref.get('allowEmptyValue')
required = parameter_deref.get('required', False)
style = parameter_deref.get('style')
explode = parameter_deref.get('explode')
schema_spec = parameter_deref.get('schema', None)
schema = None
if schema_spec:
schema, _ = self.schemas_registry.get_or_create(schema_spec)
yield (
parameter_name,
Parameter(
parameter_name, parameter_in,
schema=schema, required=required,
allow_empty_value=allow_empty_value,
style=style, explode=explode,
),
)
@property
@lru_cache()
def parameter_factory(self):
return ParameterFactory(self.dereferencer, self.schemas_registry)

View file

@ -1,7 +1,6 @@
"""OpenAPI core validators module"""
from six import iteritems
from openapi_core.enums import ParameterLocation
from openapi_core.exceptions import (
OpenAPIMappingError, MissingParameter, MissingBody, InvalidResponse,
)
@ -9,6 +8,8 @@ from openapi_core.exceptions import (
class RequestParameters(dict):
valid_locations = ['path', 'query', 'headers', 'cookies']
def __getitem__(self, location):
self.validate_location(location)
@ -19,7 +20,7 @@ class RequestParameters(dict):
@classmethod
def validate_location(cls, location):
if not ParameterLocation.has_value(location):
if location not in cls.valid_locations:
raise OpenAPIMappingError(
"Unknown parameter location: {0}".format(str(location)))
@ -116,12 +117,19 @@ class RequestValidator(object):
return RequestValidationResult(errors, body, parameters)
def _get_raw_value(self, request, param):
location = request.parameters[param.location.value]
try:
return request.parameters[param.location.value][param.name]
raw = request.parameters[param.location.value][param.name]
except KeyError:
raise MissingParameter(
"Missing required `{0}` parameter".format(param.name))
if param.aslist and param.explode:
return location.getlist(param.name)
return raw
def _get_raw_body(self, request):
if not request.body:
raise MissingBody("Missing required request body")

View file

@ -2,6 +2,7 @@
import warnings
from six.moves.urllib.parse import urljoin
from werkzeug.datastructures import ImmutableMultiDict
class BaseOpenAPIRequest(object):
@ -54,9 +55,9 @@ class MockRequest(BaseOpenAPIRequest):
self.parameters = {
'path': view_args or {},
'query': args or {},
'headers': headers or {},
'cookies': cookies or {},
'query': ImmutableMultiDict(args or []),
'header': headers or {},
'cookie': cookies or {},
}
self.body = data or ''

View file

@ -49,6 +49,14 @@ paths:
items:
type: integer
format: int32
- name: tags
in: query
description: Filter pets with tags
schema:
type: array
items:
$ref: "#/components/schemas/Tag"
explode: false
responses:
'200':
description: An paged array of pets

View file

@ -177,6 +177,41 @@ class TestPetstore(object):
assert type(schema) == Schema
def test_get_pets(self, spec, response_validator):
host_url = 'http://petstore.swagger.io/v1'
path_pattern = '/v1/pets'
query_params = {
'limit': '20',
}
request = MockRequest(
host_url, 'GET', '/pets',
path_pattern=path_pattern, args=query_params,
)
parameters = request.get_parameters(spec)
body = request.get_body(spec)
assert parameters == {
'query': {
'limit': 20,
'page': 1,
'search': '',
}
}
assert body is None
data_json = {
'data': [],
}
data = json.dumps(data_json)
response = MockResponse(data)
response_result = response_validator.validate(request, response)
assert response_result.errors == []
assert response_result.data == data_json
def test_get_pets_ids_param(self, spec, response_validator):
host_url = 'http://petstore.swagger.io/v1'
path_pattern = '/v1/pets'
query_params = {
@ -213,13 +248,13 @@ class TestPetstore(object):
assert response_result.errors == []
assert response_result.data == data_json
def test_get_pets_tag_param(self, spec, response_validator):
def test_get_pets_tags_param(self, spec, response_validator):
host_url = 'http://petstore.swagger.io/v1'
path_pattern = '/v1/pets'
query_params = {
'limit': '20',
'ids': ['12', '13'],
}
query_params = [
('limit', '20'),
('tags', 'cats,dogs'),
]
request = MockRequest(
host_url, 'GET', '/pets',
@ -234,7 +269,7 @@ class TestPetstore(object):
'limit': 20,
'page': 1,
'search': '',
'ids': [12, 13],
'tags': ['cats', 'dogs'],
}
}
assert body is None

View file

@ -13,8 +13,8 @@ class TestFlaskOpenAPIRequest(object):
server_name = 'localhost'
@pytest.fixture
def environ(self):
return create_environ()
def environ_factory(self):
return create_environ
@pytest.fixture
def map(self):
@ -33,8 +33,9 @@ class TestFlaskOpenAPIRequest(object):
], default_subdomain='www')
@pytest.fixture
def request_factory(self, map, environ):
def create_request(method, path, subdomain=None):
def request_factory(self, map, environ_factory):
def create_request(method, path, subdomain=None, query_string=None):
environ = environ_factory(query_string=query_string)
req = Request(environ)
urls = map.bind_to_environ(
environ, server_name=self.server_name, subdomain=subdomain)
@ -47,14 +48,14 @@ class TestFlaskOpenAPIRequest(object):
def openapi_request(self, request):
return FlaskOpenAPIRequest(request)
def test_simple(self, request_factory, environ, request):
def test_simple(self, request_factory, request):
request = request_factory('GET', '/', subdomain='www')
openapi_request = FlaskOpenAPIRequest(request)
path = {}
query = ImmutableMultiDict([])
headers = EnvironHeaders(environ)
headers = EnvironHeaders(request.environ)
cookies = {}
assert openapi_request.parameters == {
'path': path,
@ -69,14 +70,39 @@ class TestFlaskOpenAPIRequest(object):
assert openapi_request.body == request.data
assert openapi_request.mimetype == request.mimetype
def test_url_rule(self, request_factory, environ, request):
def test_multiple_values(self, request_factory, request):
request = request_factory(
'GET', '/', subdomain='www', query_string='a=b&a=c')
openapi_request = FlaskOpenAPIRequest(request)
path = {}
query = ImmutableMultiDict([
('a', 'b'), ('a', 'c'),
])
headers = EnvironHeaders(request.environ)
cookies = {}
assert openapi_request.parameters == {
'path': path,
'query': query,
'headers': headers,
'cookies': cookies,
}
assert openapi_request.host_url == request.host_url
assert openapi_request.path == request.path
assert openapi_request.method == request.method.lower()
assert openapi_request.path_pattern == request.path
assert openapi_request.body == request.data
assert openapi_request.mimetype == request.mimetype
def test_url_rule(self, request_factory, request):
request = request_factory('GET', '/browse/12/', subdomain='kb')
openapi_request = FlaskOpenAPIRequest(request)
path = {'id': 12}
query = ImmutableMultiDict([])
headers = EnvironHeaders(environ)
headers = EnvironHeaders(request.environ)
cookies = {}
assert openapi_request.parameters == {
'path': path,