diff --git a/openapi_core/schema/media_types/models.py b/openapi_core/schema/media_types/models.py index 5cc278b..dce3737 100644 --- a/openapi_core/schema/media_types/models.py +++ b/openapi_core/schema/media_types/models.py @@ -32,17 +32,26 @@ class MediaType(object): deserializer = self.get_dererializer() return deserializer(value) - def unmarshal(self, value, custom_formatters=None): + def cast(self, value): if not self.schema: return value try: - deserialized = self.deserialize(value) + return self.deserialize(value) except ValueError as exc: raise InvalidMediaTypeValue(exc) + def unmarshal(self, value, custom_formatters=None, resolver=None): + if not self.schema: + return value + try: - unmarshalled = self.schema.unmarshal(deserialized, custom_formatters=custom_formatters) + self.schema.validate(value, resolver=resolver) + except OpenAPISchemaError as exc: + raise InvalidMediaTypeValue(exc) + + try: + unmarshalled = self.schema.unmarshal(value, custom_formatters=custom_formatters) except OpenAPISchemaError as exc: raise InvalidMediaTypeValue(exc) diff --git a/openapi_core/schema/parameters/models.py b/openapi_core/schema/parameters/models.py index f20b4f9..74ef1ae 100644 --- a/openapi_core/schema/parameters/models.py +++ b/openapi_core/schema/parameters/models.py @@ -72,7 +72,7 @@ class Parameter(object): deserializer = self.get_dererializer() return deserializer(value) - def get_value(self, request): + def get_raw_value(self, request): location = request.parameters[self.location.value] if self.name not in location: @@ -89,7 +89,7 @@ class Parameter(object): return location[self.name] - def unmarshal(self, value, custom_formatters=None): + def cast(self, value): if self.deprecated: warnings.warn( "{0} parameter is deprecated".format(self.name), @@ -109,13 +109,22 @@ class Parameter(object): raise InvalidParameterValue(self.name, exc) try: - casted = self.schema.cast(deserialized) + return self.schema.cast(deserialized) + except OpenAPISchemaError as exc: + raise InvalidParameterValue(self.name, exc) + + def unmarshal(self, value, custom_formatters=None, resolver=None): + if not self.schema: + return value + + try: + self.schema.validate(value, resolver=resolver) except OpenAPISchemaError as exc: raise InvalidParameterValue(self.name, exc) try: unmarshalled = self.schema.unmarshal( - casted, + value, custom_formatters=custom_formatters, strict=True, ) diff --git a/openapi_core/schema/schemas/_format.py b/openapi_core/schema/schemas/_format.py new file mode 100644 index 0000000..ec65771 --- /dev/null +++ b/openapi_core/schema/schemas/_format.py @@ -0,0 +1,9 @@ +from jsonschema._format import FormatChecker +from six import binary_type + +oas30_format_checker = FormatChecker() + + +@oas30_format_checker.checks('binary') +def binary(value): + return isinstance(value, binary_type) diff --git a/openapi_core/schema/schemas/_types.py b/openapi_core/schema/schemas/_types.py new file mode 100644 index 0000000..12d1d26 --- /dev/null +++ b/openapi_core/schema/schemas/_types.py @@ -0,0 +1,21 @@ +from jsonschema._types import ( + TypeChecker, is_any, is_array, is_bool, is_integer, + is_object, is_number, +) +from six import text_type, binary_type + + +def is_string(checker, instance): + return isinstance(instance, (text_type, binary_type)) + + +oas30_type_checker = TypeChecker( + { + u"string": is_string, + u"number": is_number, + u"integer": is_integer, + u"boolean": is_bool, + u"array": is_array, + u"object": is_object, + }, +) diff --git a/openapi_core/schema/schemas/_validators.py b/openapi_core/schema/schemas/_validators.py new file mode 100644 index 0000000..19df6d3 --- /dev/null +++ b/openapi_core/schema/schemas/_validators.py @@ -0,0 +1,27 @@ +from jsonschema.exceptions import ValidationError + + +def type(validator, data_type, instance, schema): + if instance is None: + return + + if not validator.is_type(instance, data_type): + yield ValidationError("%r is not of type %s" % (instance, data_type)) + + +def items(validator, items, instance, schema): + if not validator.is_type(instance, "array"): + return + + for index, item in enumerate(instance): + for error in validator.descend(item, items, path=index): + yield error + + +def nullable(validator, is_nullable, instance, schema): + if instance is None and not is_nullable: + yield ValidationError("None for not nullable") + + +def not_implemented(validator, value, instance, schema): + pass diff --git a/openapi_core/schema/schemas/factories.py b/openapi_core/schema/schemas/factories.py index 136fbd0..96229d3 100644 --- a/openapi_core/schema/schemas/factories.py +++ b/openapi_core/schema/schemas/factories.py @@ -50,11 +50,11 @@ class SchemaFactory(object): all_of = [] if all_of_spec: - all_of = map(self.create, all_of_spec) + all_of = list(map(self.create, all_of_spec)) one_of = [] if one_of_spec: - one_of = map(self.create, one_of_spec) + one_of = list(map(self.create, one_of_spec)) items = None if items_spec: @@ -76,6 +76,7 @@ class SchemaFactory(object): exclusive_maximum=exclusive_maximum, exclusive_minimum=exclusive_minimum, min_properties=min_properties, max_properties=max_properties, + _source=schema_deref, ) @property diff --git a/openapi_core/schema/schemas/models.py b/openapi_core/schema/schemas/models.py index 17fa44f..aa38e3d 100644 --- a/openapi_core/schema/schemas/models.py +++ b/openapi_core/schema/schemas/models.py @@ -9,8 +9,10 @@ import re import warnings from six import iteritems, integer_types, binary_type, text_type +from jsonschema.exceptions import ValidationError from openapi_core.extensions.models.factories import ModelFactory +from openapi_core.schema.schemas._format import oas30_format_checker from openapi_core.schema.schemas.enums import SchemaFormat, SchemaType from openapi_core.schema.schemas.exceptions import ( InvalidSchemaValue, UndefinedSchemaProperty, MissingSchemaProperty, @@ -23,7 +25,7 @@ from openapi_core.schema.schemas.util import ( format_number, ) from openapi_core.schema.schemas.validators import ( - TypeValidator, AttributeValidator, + TypeValidator, AttributeValidator, OAS30Validator, ) log = logging.getLogger(__name__) @@ -85,7 +87,7 @@ class Schema(object): min_length=None, max_length=None, pattern=None, unique_items=False, minimum=None, maximum=None, multiple_of=None, exclusive_minimum=False, exclusive_maximum=False, - min_properties=None, max_properties=None): + min_properties=None, max_properties=None, _source=None): self.type = SchemaType(schema_type) self.model = model self.properties = properties and dict(properties) or {} @@ -119,6 +121,8 @@ class Schema(object): self._all_required_properties_cache = None self._all_optional_properties_cache = None + self._source = _source + def __getitem__(self, name): return self.properties[name] @@ -214,6 +218,18 @@ class Schema(object): return defaultdict(lambda: lambda x: x, mapping) + def get_validator(self, resolver=None): + return OAS30Validator( + self._source, resolver=resolver, format_checker=oas30_format_checker) + + def validate(self, value, resolver=None): + validator = self.get_validator(resolver=resolver) + try: + return validator.validate(value) + except ValidationError: + # TODO: pass validation errors + raise InvalidSchemaValue("Value not valid for schema", value, self.type) + def unmarshal(self, value, custom_formatters=None, strict=True): """Unmarshal parameter from the value.""" if self.deprecated: @@ -241,10 +257,7 @@ class Schema(object): "Value {value} is not of type {type}", value, self.type) except ValueError: raise InvalidSchemaValue( - "Failed to cast value {value} to type {type}", value, self.type) - - if unmarshalled is None and not self.required: - return None + "Failed to unmarshal value {value} to type {type}", value, self.type) return unmarshalled @@ -297,8 +310,7 @@ class Schema(object): return unmarshal_callable(value) except UnmarshallerStrictTypeError: continue - # @todo: remove ValueError when validation separated - except (OpenAPISchemaError, TypeError, ValueError): + except (OpenAPISchemaError, TypeError): continue raise NoValidSchema(value) @@ -307,9 +319,6 @@ class Schema(object): if not isinstance(value, (list, tuple)): raise InvalidSchemaValue("Value {value} is not of type {type}", value, self.type) - if self.items is None: - raise UndefinedItemsSchema(self.type) - f = functools.partial( self.items.unmarshal, custom_formatters=custom_formatters, strict=strict, diff --git a/openapi_core/schema/schemas/validators.py b/openapi_core/schema/schemas/validators.py index 2d9ec49..7d7b117 100644 --- a/openapi_core/schema/schemas/validators.py +++ b/openapi_core/schema/schemas/validators.py @@ -1,3 +1,10 @@ +from jsonschema import _legacy_validators, _format, _types, _utils, _validators +from jsonschema.validators import create + +from openapi_core.schema.schemas import _types as oas_types +from openapi_core.schema.schemas import _validators as oas_validators + + class TypeValidator(object): def __init__(self, *types, **options): @@ -24,3 +31,50 @@ class AttributeValidator(object): return False return True + + +OAS30Validator = create( + meta_schema=_utils.load_schema("draft4"), + validators={ + u"multipleOf": _validators.multipleOf, + u"maximum": _legacy_validators.maximum_draft3_draft4, + u"exclusiveMaximum": _validators.exclusiveMaximum, + u"minimum": _legacy_validators.minimum_draft3_draft4, + u"exclusiveMinimum": _validators.exclusiveMinimum, + u"maxLength": _validators.maxLength, + u"minLength": _validators.minLength, + u"pattern": _validators.pattern, + u"maxItems": _validators.maxItems, + u"minItems": _validators.minItems, + u"uniqueItems": _validators.uniqueItems, + u"maxProperties": _validators.maxProperties, + u"minProperties": _validators.minProperties, + u"required": _validators.required, + u"enum": _validators.enum, + # adjusted to OAS + u"type": oas_validators.type, + u"allOf": _validators.allOf, + u"oneOf": _validators.oneOf, + u"anyOf": _validators.anyOf, + u"not": _validators.not_, + u"items": oas_validators.items, + u"properties": _validators.properties, + u"additionalProperties": _validators.additionalProperties, + # TODO: adjust description + u"format": _validators.format, + # TODO: adjust default + u"$ref": _validators.ref, + # fixed OAS fields + u"nullable": oas_validators.nullable, + u"discriminator": oas_validators.not_implemented, + u"readOnly": oas_validators.not_implemented, + u"writeOnly": oas_validators.not_implemented, + u"xml": oas_validators.not_implemented, + u"externalDocs": oas_validators.not_implemented, + u"example": oas_validators.not_implemented, + u"deprecated": oas_validators.not_implemented, + }, + type_checker=oas_types.oas30_type_checker, + version="oas30", + id_of=lambda schema: schema.get(u"id", ""), +) diff --git a/openapi_core/schema/specs/factories.py b/openapi_core/schema/specs/factories.py index 16f736d..0d31dd6 100644 --- a/openapi_core/schema/specs/factories.py +++ b/openapi_core/schema/specs/factories.py @@ -2,6 +2,7 @@ """OpenAPI core specs factories module""" from openapi_spec_validator import openapi_v3_spec_validator +from openapi_spec_validator.validators import Dereferencer from openapi_core.compat import lru_cache from openapi_core.schema.components.factories import ComponentsFactory @@ -14,8 +15,8 @@ from openapi_core.schema.specs.models import Spec class SpecFactory(object): - def __init__(self, dereferencer, config=None): - self.dereferencer = dereferencer + def __init__(self, spec_resolver, config=None): + self.spec_resolver = spec_resolver self.config = config or {} def create(self, spec_dict, spec_url=''): @@ -34,9 +35,16 @@ class SpecFactory(object): paths = self.paths_generator.generate(paths) components = self.components_factory.create(components_spec) spec = Spec( - info, list(paths), servers=list(servers), components=components) + info, list(paths), servers=list(servers), components=components, + _resolver=self.spec_resolver, + ) return spec + @property + @lru_cache() + def dereferencer(self): + return Dereferencer(self.spec_resolver) + @property @lru_cache() def schemas_registry(self): diff --git a/openapi_core/schema/specs/models.py b/openapi_core/schema/specs/models.py index 7e7c4e1..f4a115e 100644 --- a/openapi_core/schema/specs/models.py +++ b/openapi_core/schema/specs/models.py @@ -14,12 +14,14 @@ log = logging.getLogger(__name__) class Spec(object): """Represents an OpenAPI Specification for a service.""" - def __init__(self, info, paths, servers=None, components=None): + def __init__(self, info, paths, servers=None, components=None, _resolver=None): self.info = info self.paths = paths and dict(paths) self.servers = servers or [] self.components = components + self._resolver = _resolver + def __getitem__(self, path_pattern): return self.get_path(path_pattern) diff --git a/openapi_core/shortcuts.py b/openapi_core/shortcuts.py index bcf4d31..02df1c1 100644 --- a/openapi_core/shortcuts.py +++ b/openapi_core/shortcuts.py @@ -1,6 +1,5 @@ """OpenAPI core shortcuts module""" from jsonschema.validators import RefResolver -from openapi_spec_validator.validators import Dereferencer from openapi_spec_validator import default_handlers from openapi_core.schema.media_types.exceptions import OpenAPIMediaTypeError @@ -17,8 +16,7 @@ from openapi_core.validation.response.validators import ResponseValidator def create_spec(spec_dict, spec_url=''): spec_resolver = RefResolver( spec_url, spec_dict, handlers=default_handlers) - dereferencer = Dereferencer(spec_resolver) - spec_factory = SpecFactory(dereferencer) + spec_factory = SpecFactory(spec_resolver) return spec_factory.create(spec_dict, spec_url=spec_url) diff --git a/openapi_core/validation/request/validators.py b/openapi_core/validation/request/validators.py index 734b589..ba4851e 100644 --- a/openapi_core/validation/request/validators.py +++ b/openapi_core/validation/request/validators.py @@ -58,7 +58,7 @@ class RequestValidator(object): continue seen.add((param_name, param.location.value)) try: - raw_value = param.get_value(request) + raw_value = param.get_raw_value(request) except MissingParameter: continue except OpenAPIMappingError as exc: @@ -66,11 +66,20 @@ class RequestValidator(object): continue try: - value = param.unmarshal(raw_value, self.custom_formatters) + casted = param.cast(raw_value) + except OpenAPIMappingError as exc: + errors.append(exc) + continue + + try: + unmarshalled = param.unmarshal( + casted, self.custom_formatters, + resolver=self.spec._resolver, + ) except OpenAPIMappingError as exc: errors.append(exc) else: - parameters[param.location.value][param_name] = value + parameters[param.location.value][param_name] = unmarshalled return parameters, errors @@ -92,8 +101,16 @@ class RequestValidator(object): errors.append(exc) else: try: - body = media_type.unmarshal(raw_body, self.custom_formatters) + casted = media_type.cast(raw_body) except OpenAPIMappingError as exc: errors.append(exc) + else: + try: + body = media_type.unmarshal( + casted, self.custom_formatters, + resolver=self.spec._resolver, + ) + except OpenAPIMappingError as exc: + errors.append(exc) return body, errors diff --git a/openapi_core/validation/response/validators.py b/openapi_core/validation/response/validators.py index 4f9696b..2ed923a 100644 --- a/openapi_core/validation/response/validators.py +++ b/openapi_core/validation/response/validators.py @@ -61,9 +61,17 @@ class ResponseValidator(object): errors.append(exc) else: try: - data = media_type.unmarshal(raw_data, self.custom_formatters) + casted = media_type.cast(raw_data) except OpenAPIMappingError as exc: errors.append(exc) + else: + try: + data = media_type.unmarshal( + casted, self.custom_formatters, + resolver=self.spec._resolver, + ) + except OpenAPIMappingError as exc: + errors.append(exc) return data, errors diff --git a/tests/integration/data/v3.0/petstore.yaml b/tests/integration/data/v3.0/petstore.yaml index efd817d..5f42339 100644 --- a/tests/integration/data/v3.0/petstore.yaml +++ b/tests/integration/data/v3.0/petstore.yaml @@ -317,7 +317,10 @@ components: suberror: $ref: "#/components/schemas/ExtendedError" additionalProperties: - type: string + oneOf: + - type: string + - type: integer + format: int32 responses: ErrorResponse: description: unexpected error diff --git a/tests/integration/test_petstore.py b/tests/integration/test_petstore.py index 4c92c91..5e57c5e 100644 --- a/tests/integration/test_petstore.py +++ b/tests/integration/test_petstore.py @@ -19,9 +19,7 @@ from openapi_core.schema.paths.models import Path from openapi_core.schema.request_bodies.models import RequestBody from openapi_core.schema.responses.models import Response from openapi_core.schema.schemas.enums import SchemaType -from openapi_core.schema.schemas.exceptions import ( - InvalidSchemaProperty, InvalidSchemaValue, -) +from openapi_core.schema.schemas.exceptions import InvalidSchemaValue from openapi_core.schema.schemas.models import Schema from openapi_core.schema.servers.exceptions import InvalidServer from openapi_core.schema.servers.models import Server, ServerVariable @@ -41,13 +39,17 @@ class TestPetstore(object): api_key_bytes_enc = b64encode(api_key_bytes) return text_type(api_key_bytes_enc, 'utf8') + @pytest.fixture + def spec_uri(self): + return "file://tests/integration/data/v3.0/petstore.yaml" + @pytest.fixture def spec_dict(self, factory): return factory.spec_from_file("data/v3.0/petstore.yaml") @pytest.fixture - def spec(self, spec_dict): - return create_spec(spec_dict) + def spec(self, spec_dict, spec_uri): + return create_spec(spec_dict, spec_uri) @pytest.fixture def request_validator(self, spec): @@ -267,6 +269,9 @@ class TestPetstore(object): { 'id': 1, 'name': 'Cat', + 'ears': { + 'healthy': True, + }, } ], } @@ -322,16 +327,10 @@ class TestPetstore(object): assert response_result.errors == [ InvalidMediaTypeValue( - original_exception=InvalidSchemaProperty( - property_name='data', - original_exception=InvalidSchemaProperty( - property_name='name', - original_exception=InvalidSchemaValue( - msg="Value {value} is not of type {type}", - type=SchemaType.STRING, - value={'first_name': 'Cat'}, - ), - ), + original_exception=InvalidSchemaValue( + msg='Value not valid for schema', + type=SchemaType.OBJECT, + value=data_json, ), ), ] @@ -932,6 +931,9 @@ class TestPetstore(object): 'data': { 'id': data_id, 'name': data_name, + 'ears': { + 'healthy': True, + }, }, } data = json.dumps(data_json) @@ -1239,7 +1241,6 @@ class TestPetstore(object): assert response_result.data.rootCause == rootCause assert response_result.data.additionalinfo == additionalinfo - @pytest.mark.xfail(reason='OneOf for string not supported atm') def test_post_tags_created_invalid_type( self, spec, response_validator): host_url = 'http://petstore.swagger.io/v1' diff --git a/tests/integration/test_validators.py b/tests/integration/test_validators.py index 2f26a7b..2d3af8f 100644 --- a/tests/integration/test_validators.py +++ b/tests/integration/test_validators.py @@ -421,6 +421,17 @@ class TestResponseValidator(object): assert result.data is None assert result.headers == {} + def test_invalid_media_type(self, validator): + request = MockRequest(self.host_url, 'get', '/v1/pets') + response = MockResponse("abcde") + + result = validator.validate(request, response) + + assert len(result.errors) == 1 + assert type(result.errors[0]) == InvalidMediaTypeValue + assert result.data is None + assert result.headers == {} + def test_invalid_media_type_value(self, validator): request = MockRequest(self.host_url, 'get', '/v1/pets') response = MockResponse("{}") @@ -458,7 +469,10 @@ class TestResponseValidator(object): 'data': [ { 'id': 1, - 'name': 'Sparky' + 'name': 'Sparky', + 'ears': { + 'healthy': True, + }, }, ], } diff --git a/tests/unit/schema/test_media_types.py b/tests/unit/schema/test_media_types.py new file mode 100644 index 0000000..2d266ac --- /dev/null +++ b/tests/unit/schema/test_media_types.py @@ -0,0 +1,53 @@ +import pytest + +from openapi_core.schema.media_types.exceptions import InvalidMediaTypeValue +from openapi_core.schema.media_types.models import MediaType +from openapi_core.schema.schemas.models import Schema + + +class TestMediaTypeCast(object): + + def test_empty(self): + media_type = MediaType('application/json') + value = '' + + result = media_type.cast(value) + + assert result == value + + +class TestParameterUnmarshal(object): + + def test_empty(self): + media_type = MediaType('application/json') + value = '' + + result = media_type.unmarshal(value) + + assert result == value + + def test_schema_type_invalid(self): + schema = Schema('integer', _source={'type': 'integer'}) + media_type = MediaType('application/json', schema=schema) + value = 'test' + + with pytest.raises(InvalidMediaTypeValue): + media_type.unmarshal(value) + + def test_schema_custom_format_invalid(self): + def custom_formatter(value): + raise ValueError + schema = Schema( + 'string', + schema_format='custom', + _source={'type': 'string', 'format': 'custom'}, + ) + custom_formatters = { + 'custom': custom_formatter, + } + media_type = MediaType('application/json', schema=schema) + value = 'test' + + with pytest.raises(InvalidMediaTypeValue): + media_type.unmarshal( + value, custom_formatters=custom_formatters) diff --git a/tests/unit/schema/test_parameters.py b/tests/unit/schema/test_parameters.py index 952e956..2755dcb 100644 --- a/tests/unit/schema/test_parameters.py +++ b/tests/unit/schema/test_parameters.py @@ -1,8 +1,11 @@ import pytest -from openapi_core.schema.parameters.exceptions import EmptyParameterValue +from openapi_core.schema.parameters.exceptions import ( + EmptyParameterValue, InvalidParameterValue, +) from openapi_core.schema.parameters.enums import ParameterStyle from openapi_core.schema.parameters.models import Parameter +from openapi_core.schema.schemas.models import Schema class TestParameterInit(object): @@ -36,17 +39,35 @@ class TestParameterInit(object): assert param.explode is True -class TestParameterUnmarshal(object): +class TestParameterCast(object): def test_deprecated(self): param = Parameter('param', 'query', deprecated=True) value = 'test' with pytest.warns(DeprecationWarning): - result = param.unmarshal(value) + result = param.cast(value) assert result == value + def test_query_empty(self): + param = Parameter('param', 'query') + value = '' + + with pytest.raises(EmptyParameterValue): + param.cast(value) + + def test_query_valid(self): + param = Parameter('param', 'query') + value = 'test' + + result = param.cast(value) + + assert result == value + + +class TestParameterUnmarshal(object): + def test_query_valid(self): param = Parameter('param', 'query') value = 'test' @@ -55,13 +76,6 @@ class TestParameterUnmarshal(object): assert result == value - def test_query_empty(self): - param = Parameter('param', 'query') - value = '' - - with pytest.raises(EmptyParameterValue): - param.unmarshal(value) - def test_query_allow_empty_value(self): param = Parameter('param', 'query', allow_empty_value=True) value = '' @@ -69,3 +83,28 @@ class TestParameterUnmarshal(object): result = param.unmarshal(value) assert result == value + + def test_query_schema_type_invalid(self): + schema = Schema('integer', _source={'type': 'integer'}) + param = Parameter('param', 'query', schema=schema) + value = 'test' + + with pytest.raises(InvalidParameterValue): + param.unmarshal(value) + + def test_query_schema_custom_format_invalid(self): + def custom_formatter(value): + raise ValueError + schema = Schema( + 'string', + schema_format='custom', + _source={'type': 'string', 'format': 'custom'}, + ) + custom_formatters = { + 'custom': custom_formatter, + } + param = Parameter('param', 'query', schema=schema) + value = 'test' + + with pytest.raises(InvalidParameterValue): + param.unmarshal(value, custom_formatters=custom_formatters) diff --git a/tests/unit/schema/test_schemas.py b/tests/unit/schema/test_schemas.py index 0e7a2fe..4436ff8 100644 --- a/tests/unit/schema/test_schemas.py +++ b/tests/unit/schema/test_schemas.py @@ -42,6 +42,17 @@ class TestSchemaUnmarshal(object): assert result == value + @pytest.mark.parametrize('schema_type', [ + 'boolean', 'array', 'integer', 'number', + ]) + def test_non_string_empty_value(self, schema_type): + schema = Schema(schema_type) + value = '' + + result = schema.unmarshal(value) + + assert result is None + def test_string_valid(self): schema = Schema('string') value = 'test' @@ -121,19 +132,28 @@ class TestSchemaUnmarshal(object): assert result == datetime.datetime(2018, 1, 2, 0, 0) - @pytest.mark.xfail(reason="No custom formats support atm") def test_string_format_custom(self): + def custom_formatter(value): + return 'x-custom' custom_format = 'custom' schema = Schema('string', schema_format=custom_format) value = 'x' - with mock.patch.dict( - Schema.STRING_FORMAT_CAST_CALLABLE_GETTER, - {custom_format: lambda x: x + '-custom'}, - ): - result = schema.unmarshal(value) + result = schema.unmarshal( + value, custom_formatters={custom_format: custom_formatter}) - assert result == 'x-custom' + assert result == custom_formatter(value) + + def test_string_format_custom_value_error(self): + def custom_formatter(value): + raise ValueError + custom_format = 'custom' + schema = Schema('string', schema_format=custom_format) + value = 'x' + + with pytest.raises(InvalidSchemaValue): + schema.unmarshal( + value, custom_formatters={custom_format: custom_formatter}) def test_string_format_unknown(self): unknown_format = 'unknown' @@ -143,7 +163,6 @@ class TestSchemaUnmarshal(object): with pytest.raises(OpenAPISchemaError): schema.unmarshal(value) - @pytest.mark.xfail(reason="No custom formats support atm") def test_string_format_invalid_value(self): custom_format = 'custom' schema = Schema('string', schema_format=custom_format) @@ -351,6 +370,14 @@ class TestSchemaObjValidate(object): assert result is None + def test_string_format_custom_missing(self): + custom_format = 'custom' + schema = Schema('string', schema_format=custom_format) + value = 'x' + + with pytest.raises(OpenAPISchemaError): + schema.obj_validate(value) + @pytest.mark.parametrize('value', [False, True]) def test_boolean(self, value): schema = Schema('boolean')