OAS 3.0 validator

This commit is contained in:
p1c2u 2019-09-03 01:38:19 +01:00
parent c4c51637d2
commit b2410e2f3a
19 changed files with 377 additions and 68 deletions

View file

@ -32,17 +32,26 @@ class MediaType(object):
deserializer = self.get_dererializer() deserializer = self.get_dererializer()
return deserializer(value) return deserializer(value)
def unmarshal(self, value, custom_formatters=None): def cast(self, value):
if not self.schema: if not self.schema:
return value return value
try: try:
deserialized = self.deserialize(value) return self.deserialize(value)
except ValueError as exc: except ValueError as exc:
raise InvalidMediaTypeValue(exc) raise InvalidMediaTypeValue(exc)
def unmarshal(self, value, custom_formatters=None, resolver=None):
if not self.schema:
return value
try: try:
unmarshalled = self.schema.unmarshal(deserialized, custom_formatters=custom_formatters) self.schema.validate(value, resolver=resolver)
except OpenAPISchemaError as exc:
raise InvalidMediaTypeValue(exc)
try:
unmarshalled = self.schema.unmarshal(value, custom_formatters=custom_formatters)
except OpenAPISchemaError as exc: except OpenAPISchemaError as exc:
raise InvalidMediaTypeValue(exc) raise InvalidMediaTypeValue(exc)

View file

@ -72,7 +72,7 @@ class Parameter(object):
deserializer = self.get_dererializer() deserializer = self.get_dererializer()
return deserializer(value) return deserializer(value)
def get_value(self, request): def get_raw_value(self, request):
location = request.parameters[self.location.value] location = request.parameters[self.location.value]
if self.name not in location: if self.name not in location:
@ -89,7 +89,7 @@ class Parameter(object):
return location[self.name] return location[self.name]
def unmarshal(self, value, custom_formatters=None): def cast(self, value):
if self.deprecated: if self.deprecated:
warnings.warn( warnings.warn(
"{0} parameter is deprecated".format(self.name), "{0} parameter is deprecated".format(self.name),
@ -109,13 +109,22 @@ class Parameter(object):
raise InvalidParameterValue(self.name, exc) raise InvalidParameterValue(self.name, exc)
try: try:
casted = self.schema.cast(deserialized) return self.schema.cast(deserialized)
except OpenAPISchemaError as exc:
raise InvalidParameterValue(self.name, exc)
def unmarshal(self, value, custom_formatters=None, resolver=None):
if not self.schema:
return value
try:
self.schema.validate(value, resolver=resolver)
except OpenAPISchemaError as exc: except OpenAPISchemaError as exc:
raise InvalidParameterValue(self.name, exc) raise InvalidParameterValue(self.name, exc)
try: try:
unmarshalled = self.schema.unmarshal( unmarshalled = self.schema.unmarshal(
casted, value,
custom_formatters=custom_formatters, custom_formatters=custom_formatters,
strict=True, strict=True,
) )

View file

@ -0,0 +1,9 @@
from jsonschema._format import FormatChecker
from six import binary_type
oas30_format_checker = FormatChecker()
@oas30_format_checker.checks('binary')
def binary(value):
return isinstance(value, binary_type)

View file

@ -0,0 +1,21 @@
from jsonschema._types import (
TypeChecker, is_any, is_array, is_bool, is_integer,
is_object, is_number,
)
from six import text_type, binary_type
def is_string(checker, instance):
return isinstance(instance, (text_type, binary_type))
oas30_type_checker = TypeChecker(
{
u"string": is_string,
u"number": is_number,
u"integer": is_integer,
u"boolean": is_bool,
u"array": is_array,
u"object": is_object,
},
)

View file

@ -0,0 +1,27 @@
from jsonschema.exceptions import ValidationError
def type(validator, data_type, instance, schema):
if instance is None:
return
if not validator.is_type(instance, data_type):
yield ValidationError("%r is not of type %s" % (instance, data_type))
def items(validator, items, instance, schema):
if not validator.is_type(instance, "array"):
return
for index, item in enumerate(instance):
for error in validator.descend(item, items, path=index):
yield error
def nullable(validator, is_nullable, instance, schema):
if instance is None and not is_nullable:
yield ValidationError("None for not nullable")
def not_implemented(validator, value, instance, schema):
pass

View file

@ -50,11 +50,11 @@ class SchemaFactory(object):
all_of = [] all_of = []
if all_of_spec: if all_of_spec:
all_of = map(self.create, all_of_spec) all_of = list(map(self.create, all_of_spec))
one_of = [] one_of = []
if one_of_spec: if one_of_spec:
one_of = map(self.create, one_of_spec) one_of = list(map(self.create, one_of_spec))
items = None items = None
if items_spec: if items_spec:
@ -76,6 +76,7 @@ class SchemaFactory(object):
exclusive_maximum=exclusive_maximum, exclusive_maximum=exclusive_maximum,
exclusive_minimum=exclusive_minimum, exclusive_minimum=exclusive_minimum,
min_properties=min_properties, max_properties=max_properties, min_properties=min_properties, max_properties=max_properties,
_source=schema_deref,
) )
@property @property

View file

@ -9,8 +9,10 @@ import re
import warnings import warnings
from six import iteritems, integer_types, binary_type, text_type from six import iteritems, integer_types, binary_type, text_type
from jsonschema.exceptions import ValidationError
from openapi_core.extensions.models.factories import ModelFactory from openapi_core.extensions.models.factories import ModelFactory
from openapi_core.schema.schemas._format import oas30_format_checker
from openapi_core.schema.schemas.enums import SchemaFormat, SchemaType from openapi_core.schema.schemas.enums import SchemaFormat, SchemaType
from openapi_core.schema.schemas.exceptions import ( from openapi_core.schema.schemas.exceptions import (
InvalidSchemaValue, UndefinedSchemaProperty, MissingSchemaProperty, InvalidSchemaValue, UndefinedSchemaProperty, MissingSchemaProperty,
@ -23,7 +25,7 @@ from openapi_core.schema.schemas.util import (
format_number, format_number,
) )
from openapi_core.schema.schemas.validators import ( from openapi_core.schema.schemas.validators import (
TypeValidator, AttributeValidator, TypeValidator, AttributeValidator, OAS30Validator,
) )
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -85,7 +87,7 @@ class Schema(object):
min_length=None, max_length=None, pattern=None, unique_items=False, min_length=None, max_length=None, pattern=None, unique_items=False,
minimum=None, maximum=None, multiple_of=None, minimum=None, maximum=None, multiple_of=None,
exclusive_minimum=False, exclusive_maximum=False, exclusive_minimum=False, exclusive_maximum=False,
min_properties=None, max_properties=None): min_properties=None, max_properties=None, _source=None):
self.type = SchemaType(schema_type) self.type = SchemaType(schema_type)
self.model = model self.model = model
self.properties = properties and dict(properties) or {} self.properties = properties and dict(properties) or {}
@ -119,6 +121,8 @@ class Schema(object):
self._all_required_properties_cache = None self._all_required_properties_cache = None
self._all_optional_properties_cache = None self._all_optional_properties_cache = None
self._source = _source
def __getitem__(self, name): def __getitem__(self, name):
return self.properties[name] return self.properties[name]
@ -214,6 +218,18 @@ class Schema(object):
return defaultdict(lambda: lambda x: x, mapping) return defaultdict(lambda: lambda x: x, mapping)
def get_validator(self, resolver=None):
return OAS30Validator(
self._source, resolver=resolver, format_checker=oas30_format_checker)
def validate(self, value, resolver=None):
validator = self.get_validator(resolver=resolver)
try:
return validator.validate(value)
except ValidationError:
# TODO: pass validation errors
raise InvalidSchemaValue("Value not valid for schema", value, self.type)
def unmarshal(self, value, custom_formatters=None, strict=True): def unmarshal(self, value, custom_formatters=None, strict=True):
"""Unmarshal parameter from the value.""" """Unmarshal parameter from the value."""
if self.deprecated: if self.deprecated:
@ -241,10 +257,7 @@ class Schema(object):
"Value {value} is not of type {type}", value, self.type) "Value {value} is not of type {type}", value, self.type)
except ValueError: except ValueError:
raise InvalidSchemaValue( raise InvalidSchemaValue(
"Failed to cast value {value} to type {type}", value, self.type) "Failed to unmarshal value {value} to type {type}", value, self.type)
if unmarshalled is None and not self.required:
return None
return unmarshalled return unmarshalled
@ -297,8 +310,7 @@ class Schema(object):
return unmarshal_callable(value) return unmarshal_callable(value)
except UnmarshallerStrictTypeError: except UnmarshallerStrictTypeError:
continue continue
# @todo: remove ValueError when validation separated except (OpenAPISchemaError, TypeError):
except (OpenAPISchemaError, TypeError, ValueError):
continue continue
raise NoValidSchema(value) raise NoValidSchema(value)
@ -307,9 +319,6 @@ class Schema(object):
if not isinstance(value, (list, tuple)): if not isinstance(value, (list, tuple)):
raise InvalidSchemaValue("Value {value} is not of type {type}", value, self.type) raise InvalidSchemaValue("Value {value} is not of type {type}", value, self.type)
if self.items is None:
raise UndefinedItemsSchema(self.type)
f = functools.partial( f = functools.partial(
self.items.unmarshal, self.items.unmarshal,
custom_formatters=custom_formatters, strict=strict, custom_formatters=custom_formatters, strict=strict,

View file

@ -1,3 +1,10 @@
from jsonschema import _legacy_validators, _format, _types, _utils, _validators
from jsonschema.validators import create
from openapi_core.schema.schemas import _types as oas_types
from openapi_core.schema.schemas import _validators as oas_validators
class TypeValidator(object): class TypeValidator(object):
def __init__(self, *types, **options): def __init__(self, *types, **options):
@ -24,3 +31,50 @@ class AttributeValidator(object):
return False return False
return True return True
OAS30Validator = create(
meta_schema=_utils.load_schema("draft4"),
validators={
u"multipleOf": _validators.multipleOf,
u"maximum": _legacy_validators.maximum_draft3_draft4,
u"exclusiveMaximum": _validators.exclusiveMaximum,
u"minimum": _legacy_validators.minimum_draft3_draft4,
u"exclusiveMinimum": _validators.exclusiveMinimum,
u"maxLength": _validators.maxLength,
u"minLength": _validators.minLength,
u"pattern": _validators.pattern,
u"maxItems": _validators.maxItems,
u"minItems": _validators.minItems,
u"uniqueItems": _validators.uniqueItems,
u"maxProperties": _validators.maxProperties,
u"minProperties": _validators.minProperties,
u"required": _validators.required,
u"enum": _validators.enum,
# adjusted to OAS
u"type": oas_validators.type,
u"allOf": _validators.allOf,
u"oneOf": _validators.oneOf,
u"anyOf": _validators.anyOf,
u"not": _validators.not_,
u"items": oas_validators.items,
u"properties": _validators.properties,
u"additionalProperties": _validators.additionalProperties,
# TODO: adjust description
u"format": _validators.format,
# TODO: adjust default
u"$ref": _validators.ref,
# fixed OAS fields
u"nullable": oas_validators.nullable,
u"discriminator": oas_validators.not_implemented,
u"readOnly": oas_validators.not_implemented,
u"writeOnly": oas_validators.not_implemented,
u"xml": oas_validators.not_implemented,
u"externalDocs": oas_validators.not_implemented,
u"example": oas_validators.not_implemented,
u"deprecated": oas_validators.not_implemented,
},
type_checker=oas_types.oas30_type_checker,
version="oas30",
id_of=lambda schema: schema.get(u"id", ""),
)

View file

@ -2,6 +2,7 @@
"""OpenAPI core specs factories module""" """OpenAPI core specs factories module"""
from openapi_spec_validator import openapi_v3_spec_validator from openapi_spec_validator import openapi_v3_spec_validator
from openapi_spec_validator.validators import Dereferencer
from openapi_core.compat import lru_cache from openapi_core.compat import lru_cache
from openapi_core.schema.components.factories import ComponentsFactory from openapi_core.schema.components.factories import ComponentsFactory
@ -14,8 +15,8 @@ from openapi_core.schema.specs.models import Spec
class SpecFactory(object): class SpecFactory(object):
def __init__(self, dereferencer, config=None): def __init__(self, spec_resolver, config=None):
self.dereferencer = dereferencer self.spec_resolver = spec_resolver
self.config = config or {} self.config = config or {}
def create(self, spec_dict, spec_url=''): def create(self, spec_dict, spec_url=''):
@ -34,9 +35,16 @@ class SpecFactory(object):
paths = self.paths_generator.generate(paths) paths = self.paths_generator.generate(paths)
components = self.components_factory.create(components_spec) components = self.components_factory.create(components_spec)
spec = Spec( spec = Spec(
info, list(paths), servers=list(servers), components=components) info, list(paths), servers=list(servers), components=components,
_resolver=self.spec_resolver,
)
return spec return spec
@property
@lru_cache()
def dereferencer(self):
return Dereferencer(self.spec_resolver)
@property @property
@lru_cache() @lru_cache()
def schemas_registry(self): def schemas_registry(self):

View file

@ -14,12 +14,14 @@ log = logging.getLogger(__name__)
class Spec(object): class Spec(object):
"""Represents an OpenAPI Specification for a service.""" """Represents an OpenAPI Specification for a service."""
def __init__(self, info, paths, servers=None, components=None): def __init__(self, info, paths, servers=None, components=None, _resolver=None):
self.info = info self.info = info
self.paths = paths and dict(paths) self.paths = paths and dict(paths)
self.servers = servers or [] self.servers = servers or []
self.components = components self.components = components
self._resolver = _resolver
def __getitem__(self, path_pattern): def __getitem__(self, path_pattern):
return self.get_path(path_pattern) return self.get_path(path_pattern)

View file

@ -1,6 +1,5 @@
"""OpenAPI core shortcuts module""" """OpenAPI core shortcuts module"""
from jsonschema.validators import RefResolver from jsonschema.validators import RefResolver
from openapi_spec_validator.validators import Dereferencer
from openapi_spec_validator import default_handlers from openapi_spec_validator import default_handlers
from openapi_core.schema.media_types.exceptions import OpenAPIMediaTypeError from openapi_core.schema.media_types.exceptions import OpenAPIMediaTypeError
@ -17,8 +16,7 @@ from openapi_core.validation.response.validators import ResponseValidator
def create_spec(spec_dict, spec_url=''): def create_spec(spec_dict, spec_url=''):
spec_resolver = RefResolver( spec_resolver = RefResolver(
spec_url, spec_dict, handlers=default_handlers) spec_url, spec_dict, handlers=default_handlers)
dereferencer = Dereferencer(spec_resolver) spec_factory = SpecFactory(spec_resolver)
spec_factory = SpecFactory(dereferencer)
return spec_factory.create(spec_dict, spec_url=spec_url) return spec_factory.create(spec_dict, spec_url=spec_url)

View file

@ -58,7 +58,7 @@ class RequestValidator(object):
continue continue
seen.add((param_name, param.location.value)) seen.add((param_name, param.location.value))
try: try:
raw_value = param.get_value(request) raw_value = param.get_raw_value(request)
except MissingParameter: except MissingParameter:
continue continue
except OpenAPIMappingError as exc: except OpenAPIMappingError as exc:
@ -66,11 +66,20 @@ class RequestValidator(object):
continue continue
try: try:
value = param.unmarshal(raw_value, self.custom_formatters) casted = param.cast(raw_value)
except OpenAPIMappingError as exc:
errors.append(exc)
continue
try:
unmarshalled = param.unmarshal(
casted, self.custom_formatters,
resolver=self.spec._resolver,
)
except OpenAPIMappingError as exc: except OpenAPIMappingError as exc:
errors.append(exc) errors.append(exc)
else: else:
parameters[param.location.value][param_name] = value parameters[param.location.value][param_name] = unmarshalled
return parameters, errors return parameters, errors
@ -92,7 +101,15 @@ class RequestValidator(object):
errors.append(exc) errors.append(exc)
else: else:
try: try:
body = media_type.unmarshal(raw_body, self.custom_formatters) casted = media_type.cast(raw_body)
except OpenAPIMappingError as exc:
errors.append(exc)
else:
try:
body = media_type.unmarshal(
casted, self.custom_formatters,
resolver=self.spec._resolver,
)
except OpenAPIMappingError as exc: except OpenAPIMappingError as exc:
errors.append(exc) errors.append(exc)

View file

@ -61,7 +61,15 @@ class ResponseValidator(object):
errors.append(exc) errors.append(exc)
else: else:
try: try:
data = media_type.unmarshal(raw_data, self.custom_formatters) casted = media_type.cast(raw_data)
except OpenAPIMappingError as exc:
errors.append(exc)
else:
try:
data = media_type.unmarshal(
casted, self.custom_formatters,
resolver=self.spec._resolver,
)
except OpenAPIMappingError as exc: except OpenAPIMappingError as exc:
errors.append(exc) errors.append(exc)

View file

@ -317,7 +317,10 @@ components:
suberror: suberror:
$ref: "#/components/schemas/ExtendedError" $ref: "#/components/schemas/ExtendedError"
additionalProperties: additionalProperties:
type: string oneOf:
- type: string
- type: integer
format: int32
responses: responses:
ErrorResponse: ErrorResponse:
description: unexpected error description: unexpected error

View file

@ -19,9 +19,7 @@ from openapi_core.schema.paths.models import Path
from openapi_core.schema.request_bodies.models import RequestBody from openapi_core.schema.request_bodies.models import RequestBody
from openapi_core.schema.responses.models import Response from openapi_core.schema.responses.models import Response
from openapi_core.schema.schemas.enums import SchemaType from openapi_core.schema.schemas.enums import SchemaType
from openapi_core.schema.schemas.exceptions import ( from openapi_core.schema.schemas.exceptions import InvalidSchemaValue
InvalidSchemaProperty, InvalidSchemaValue,
)
from openapi_core.schema.schemas.models import Schema from openapi_core.schema.schemas.models import Schema
from openapi_core.schema.servers.exceptions import InvalidServer from openapi_core.schema.servers.exceptions import InvalidServer
from openapi_core.schema.servers.models import Server, ServerVariable from openapi_core.schema.servers.models import Server, ServerVariable
@ -41,13 +39,17 @@ class TestPetstore(object):
api_key_bytes_enc = b64encode(api_key_bytes) api_key_bytes_enc = b64encode(api_key_bytes)
return text_type(api_key_bytes_enc, 'utf8') return text_type(api_key_bytes_enc, 'utf8')
@pytest.fixture
def spec_uri(self):
return "file://tests/integration/data/v3.0/petstore.yaml"
@pytest.fixture @pytest.fixture
def spec_dict(self, factory): def spec_dict(self, factory):
return factory.spec_from_file("data/v3.0/petstore.yaml") return factory.spec_from_file("data/v3.0/petstore.yaml")
@pytest.fixture @pytest.fixture
def spec(self, spec_dict): def spec(self, spec_dict, spec_uri):
return create_spec(spec_dict) return create_spec(spec_dict, spec_uri)
@pytest.fixture @pytest.fixture
def request_validator(self, spec): def request_validator(self, spec):
@ -267,6 +269,9 @@ class TestPetstore(object):
{ {
'id': 1, 'id': 1,
'name': 'Cat', 'name': 'Cat',
'ears': {
'healthy': True,
},
} }
], ],
} }
@ -322,16 +327,10 @@ class TestPetstore(object):
assert response_result.errors == [ assert response_result.errors == [
InvalidMediaTypeValue( InvalidMediaTypeValue(
original_exception=InvalidSchemaProperty(
property_name='data',
original_exception=InvalidSchemaProperty(
property_name='name',
original_exception=InvalidSchemaValue( original_exception=InvalidSchemaValue(
msg="Value {value} is not of type {type}", msg='Value not valid for schema',
type=SchemaType.STRING, type=SchemaType.OBJECT,
value={'first_name': 'Cat'}, value=data_json,
),
),
), ),
), ),
] ]
@ -932,6 +931,9 @@ class TestPetstore(object):
'data': { 'data': {
'id': data_id, 'id': data_id,
'name': data_name, 'name': data_name,
'ears': {
'healthy': True,
},
}, },
} }
data = json.dumps(data_json) data = json.dumps(data_json)
@ -1239,7 +1241,6 @@ class TestPetstore(object):
assert response_result.data.rootCause == rootCause assert response_result.data.rootCause == rootCause
assert response_result.data.additionalinfo == additionalinfo assert response_result.data.additionalinfo == additionalinfo
@pytest.mark.xfail(reason='OneOf for string not supported atm')
def test_post_tags_created_invalid_type( def test_post_tags_created_invalid_type(
self, spec, response_validator): self, spec, response_validator):
host_url = 'http://petstore.swagger.io/v1' host_url = 'http://petstore.swagger.io/v1'

View file

@ -421,6 +421,17 @@ class TestResponseValidator(object):
assert result.data is None assert result.data is None
assert result.headers == {} assert result.headers == {}
def test_invalid_media_type(self, validator):
request = MockRequest(self.host_url, 'get', '/v1/pets')
response = MockResponse("abcde")
result = validator.validate(request, response)
assert len(result.errors) == 1
assert type(result.errors[0]) == InvalidMediaTypeValue
assert result.data is None
assert result.headers == {}
def test_invalid_media_type_value(self, validator): def test_invalid_media_type_value(self, validator):
request = MockRequest(self.host_url, 'get', '/v1/pets') request = MockRequest(self.host_url, 'get', '/v1/pets')
response = MockResponse("{}") response = MockResponse("{}")
@ -458,7 +469,10 @@ class TestResponseValidator(object):
'data': [ 'data': [
{ {
'id': 1, 'id': 1,
'name': 'Sparky' 'name': 'Sparky',
'ears': {
'healthy': True,
},
}, },
], ],
} }

View file

@ -0,0 +1,53 @@
import pytest
from openapi_core.schema.media_types.exceptions import InvalidMediaTypeValue
from openapi_core.schema.media_types.models import MediaType
from openapi_core.schema.schemas.models import Schema
class TestMediaTypeCast(object):
def test_empty(self):
media_type = MediaType('application/json')
value = ''
result = media_type.cast(value)
assert result == value
class TestParameterUnmarshal(object):
def test_empty(self):
media_type = MediaType('application/json')
value = ''
result = media_type.unmarshal(value)
assert result == value
def test_schema_type_invalid(self):
schema = Schema('integer', _source={'type': 'integer'})
media_type = MediaType('application/json', schema=schema)
value = 'test'
with pytest.raises(InvalidMediaTypeValue):
media_type.unmarshal(value)
def test_schema_custom_format_invalid(self):
def custom_formatter(value):
raise ValueError
schema = Schema(
'string',
schema_format='custom',
_source={'type': 'string', 'format': 'custom'},
)
custom_formatters = {
'custom': custom_formatter,
}
media_type = MediaType('application/json', schema=schema)
value = 'test'
with pytest.raises(InvalidMediaTypeValue):
media_type.unmarshal(
value, custom_formatters=custom_formatters)

View file

@ -1,8 +1,11 @@
import pytest import pytest
from openapi_core.schema.parameters.exceptions import EmptyParameterValue from openapi_core.schema.parameters.exceptions import (
EmptyParameterValue, InvalidParameterValue,
)
from openapi_core.schema.parameters.enums import ParameterStyle from openapi_core.schema.parameters.enums import ParameterStyle
from openapi_core.schema.parameters.models import Parameter from openapi_core.schema.parameters.models import Parameter
from openapi_core.schema.schemas.models import Schema
class TestParameterInit(object): class TestParameterInit(object):
@ -36,17 +39,35 @@ class TestParameterInit(object):
assert param.explode is True assert param.explode is True
class TestParameterUnmarshal(object): class TestParameterCast(object):
def test_deprecated(self): def test_deprecated(self):
param = Parameter('param', 'query', deprecated=True) param = Parameter('param', 'query', deprecated=True)
value = 'test' value = 'test'
with pytest.warns(DeprecationWarning): with pytest.warns(DeprecationWarning):
result = param.unmarshal(value) result = param.cast(value)
assert result == value assert result == value
def test_query_empty(self):
param = Parameter('param', 'query')
value = ''
with pytest.raises(EmptyParameterValue):
param.cast(value)
def test_query_valid(self):
param = Parameter('param', 'query')
value = 'test'
result = param.cast(value)
assert result == value
class TestParameterUnmarshal(object):
def test_query_valid(self): def test_query_valid(self):
param = Parameter('param', 'query') param = Parameter('param', 'query')
value = 'test' value = 'test'
@ -55,13 +76,6 @@ class TestParameterUnmarshal(object):
assert result == value assert result == value
def test_query_empty(self):
param = Parameter('param', 'query')
value = ''
with pytest.raises(EmptyParameterValue):
param.unmarshal(value)
def test_query_allow_empty_value(self): def test_query_allow_empty_value(self):
param = Parameter('param', 'query', allow_empty_value=True) param = Parameter('param', 'query', allow_empty_value=True)
value = '' value = ''
@ -69,3 +83,28 @@ class TestParameterUnmarshal(object):
result = param.unmarshal(value) result = param.unmarshal(value)
assert result == value assert result == value
def test_query_schema_type_invalid(self):
schema = Schema('integer', _source={'type': 'integer'})
param = Parameter('param', 'query', schema=schema)
value = 'test'
with pytest.raises(InvalidParameterValue):
param.unmarshal(value)
def test_query_schema_custom_format_invalid(self):
def custom_formatter(value):
raise ValueError
schema = Schema(
'string',
schema_format='custom',
_source={'type': 'string', 'format': 'custom'},
)
custom_formatters = {
'custom': custom_formatter,
}
param = Parameter('param', 'query', schema=schema)
value = 'test'
with pytest.raises(InvalidParameterValue):
param.unmarshal(value, custom_formatters=custom_formatters)

View file

@ -42,6 +42,17 @@ class TestSchemaUnmarshal(object):
assert result == value assert result == value
@pytest.mark.parametrize('schema_type', [
'boolean', 'array', 'integer', 'number',
])
def test_non_string_empty_value(self, schema_type):
schema = Schema(schema_type)
value = ''
result = schema.unmarshal(value)
assert result is None
def test_string_valid(self): def test_string_valid(self):
schema = Schema('string') schema = Schema('string')
value = 'test' value = 'test'
@ -121,19 +132,28 @@ class TestSchemaUnmarshal(object):
assert result == datetime.datetime(2018, 1, 2, 0, 0) assert result == datetime.datetime(2018, 1, 2, 0, 0)
@pytest.mark.xfail(reason="No custom formats support atm")
def test_string_format_custom(self): def test_string_format_custom(self):
def custom_formatter(value):
return 'x-custom'
custom_format = 'custom' custom_format = 'custom'
schema = Schema('string', schema_format=custom_format) schema = Schema('string', schema_format=custom_format)
value = 'x' value = 'x'
with mock.patch.dict( result = schema.unmarshal(
Schema.STRING_FORMAT_CAST_CALLABLE_GETTER, value, custom_formatters={custom_format: custom_formatter})
{custom_format: lambda x: x + '-custom'},
):
result = schema.unmarshal(value)
assert result == 'x-custom' assert result == custom_formatter(value)
def test_string_format_custom_value_error(self):
def custom_formatter(value):
raise ValueError
custom_format = 'custom'
schema = Schema('string', schema_format=custom_format)
value = 'x'
with pytest.raises(InvalidSchemaValue):
schema.unmarshal(
value, custom_formatters={custom_format: custom_formatter})
def test_string_format_unknown(self): def test_string_format_unknown(self):
unknown_format = 'unknown' unknown_format = 'unknown'
@ -143,7 +163,6 @@ class TestSchemaUnmarshal(object):
with pytest.raises(OpenAPISchemaError): with pytest.raises(OpenAPISchemaError):
schema.unmarshal(value) schema.unmarshal(value)
@pytest.mark.xfail(reason="No custom formats support atm")
def test_string_format_invalid_value(self): def test_string_format_invalid_value(self):
custom_format = 'custom' custom_format = 'custom'
schema = Schema('string', schema_format=custom_format) schema = Schema('string', schema_format=custom_format)
@ -351,6 +370,14 @@ class TestSchemaObjValidate(object):
assert result is None assert result is None
def test_string_format_custom_missing(self):
custom_format = 'custom'
schema = Schema('string', schema_format=custom_format)
value = 'x'
with pytest.raises(OpenAPISchemaError):
schema.obj_validate(value)
@pytest.mark.parametrize('value', [False, True]) @pytest.mark.parametrize('value', [False, True])
def test_boolean(self, value): def test_boolean(self, value):
schema = Schema('boolean') schema = Schema('boolean')