diff --git a/openapi_core/schema/media_types/models.py b/openapi_core/schema/media_types/models.py index d47e696..760587f 100644 --- a/openapi_core/schema/media_types/models.py +++ b/openapi_core/schema/media_types/models.py @@ -32,7 +32,7 @@ class MediaType(object): deserializer = self.get_dererializer() return deserializer(value) - def unmarshal(self, value): + def unmarshal(self, value, custom_formatters=None): if not self.schema: return value @@ -42,7 +42,7 @@ class MediaType(object): raise InvalidMediaTypeValue(str(exc)) try: - unmarshalled = self.schema.unmarshal(deserialized) + unmarshalled = self.schema.unmarshal(deserialized, custom_formatters) except InvalidSchemaValue as exc: raise InvalidMediaTypeValue(str(exc)) diff --git a/openapi_core/schema/parameters/models.py b/openapi_core/schema/parameters/models.py index 7ca6657..c515128 100644 --- a/openapi_core/schema/parameters/models.py +++ b/openapi_core/schema/parameters/models.py @@ -91,7 +91,7 @@ class Parameter(object): return location[self.name] - def unmarshal(self, value): + def unmarshal(self, value, custom_formatters=None): if self.deprecated: warnings.warn( "{0} parameter is deprecated".format(self.name), @@ -112,7 +112,7 @@ class Parameter(object): raise InvalidParameterValue(str(exc)) try: - unmarshalled = self.schema.unmarshal(deserialized) + unmarshalled = self.schema.unmarshal(deserialized, custom_formatters) except InvalidSchemaValue as exc: raise InvalidParameterValue(str(exc)) diff --git a/openapi_core/schema/schemas/models.py b/openapi_core/schema/schemas/models.py index c0014c5..31139c7 100644 --- a/openapi_core/schema/schemas/models.py +++ b/openapi_core/schema/schemas/models.py @@ -1,4 +1,6 @@ """OpenAPI core schemas models module""" +import attr +import functools import logging from collections import defaultdict from datetime import date, datetime @@ -23,6 +25,11 @@ from openapi_core.schema.schemas.validators import ( log = logging.getLogger(__name__) +@attr.s +class StringFormat(object): + format = attr.ib() + validate = attr.ib() + class Schema(object): """Represents an OpenAPI Schema.""" @@ -33,18 +40,11 @@ class Schema(object): SchemaType.BOOLEAN: forcebool, } - STRING_FORMAT_CAST_CALLABLE_GETTER = { - SchemaFormat.NONE: text_type, - SchemaFormat.DATE: format_date, - SchemaFormat.DATETIME: format_datetime, - SchemaFormat.BINARY: binary_type, - } - - STRING_FORMAT_VALIDATOR_CALLABLE_GETTER = { - SchemaFormat.NONE: TypeValidator(text_type), - SchemaFormat.DATE: TypeValidator(date, exclude=datetime), - SchemaFormat.DATETIME: TypeValidator(datetime), - SchemaFormat.BINARY: TypeValidator(binary_type), + STRING_FORMAT_CALLABLE_GETTER = { + SchemaFormat.NONE: StringFormat(text_type, TypeValidator(text_type)), + SchemaFormat.DATE: StringFormat(format_date, TypeValidator(date, exclude=datetime)), + SchemaFormat.DATETIME: StringFormat(format_datetime, TypeValidator(datetime)), + SchemaFormat.BINARY: StringFormat(binary_type, TypeValidator(binary_type)), } TYPE_VALIDATOR_CALLABLE_GETTER = { @@ -99,6 +99,7 @@ class Schema(object): self._all_required_properties_cache = None self._all_optional_properties_cache = None + self.custom_formatters = None def __getitem__(self, name): return self.properties[name] @@ -173,11 +174,13 @@ class Schema(object): "Failed to cast value of {0} to {1}".format(value, self.type) ) - def unmarshal(self, value): + def unmarshal(self, value, custom_formatters=None): """Unmarshal parameter from the value.""" if self.deprecated: warnings.warn("The schema is deprecated", DeprecationWarning) + self.custom_formatters = custom_formatters + casted = self.cast(value) if casted is None and not self.required: @@ -195,15 +198,18 @@ class Schema(object): try: schema_format = SchemaFormat(self.format) except ValueError: - # @todo: implement custom format unmarshalling support - raise OpenAPISchemaError( - "Unsupported {0} format unmarshalling".format(self.format) - ) + msg = "Unsupported {0} format unmarshalling".format(self.format) + if self.custom_formatters is not None: + formatstring = self.custom_formatters.get(self.format) + if formatstring is None: + raise OpenAPISchemaError(msg) + else: + raise OpenAPISchemaError(msg) else: - formatter = self.STRING_FORMAT_CAST_CALLABLE_GETTER[schema_format] + formatstring = self.STRING_FORMAT_CALLABLE_GETTER[schema_format] try: - return formatter(value) + return formatstring.format(value) except ValueError: raise InvalidSchemaValue( "Failed to format value of {0} to {1}".format( @@ -231,7 +237,8 @@ class Schema(object): if self.items is None: raise UndefinedItemsSchema("Undefined items' schema") - return list(map(self.items.unmarshal, value)) + f = functools.partial(self.items.unmarshal, custom_formatters=self.custom_formatters) + return list(map(f, value)) def _unmarshal_object(self, value, model_factory=None): if not isinstance(value, (dict, )): @@ -286,7 +293,7 @@ class Schema(object): for prop_name in extra_props: prop_value = value[prop_name] properties[prop_name] = self.additional_properties.unmarshal( - prop_value) + prop_value, self.custom_formatters) for prop_name, prop in iteritems(all_props): try: @@ -298,7 +305,7 @@ class Schema(object): if not prop.nullable and not prop.default: continue prop_value = prop.default - properties[prop_name] = prop.unmarshal(prop_value) + properties[prop_name] = prop.unmarshal(prop_value, self.custom_formatters) self._validate_properties(properties, one_of_schema=one_of_schema) @@ -405,15 +412,18 @@ class Schema(object): try: schema_format = SchemaFormat(self.format) except ValueError: - # @todo: implement custom format validation support - raise OpenAPISchemaError( - "Unsupported {0} format validation".format(self.format) - ) + msg = "Unsupported {0} format validation".format(self.format) + if self.custom_formatters is not None: + formatstring = self.custom_formatters.get(self.format) + if formatstring is None: + raise OpenAPISchemaError(msg) + else: + raise OpenAPISchemaError(msg) else: - format_validator_callable =\ - self.STRING_FORMAT_VALIDATOR_CALLABLE_GETTER[schema_format] + formatstring =\ + self.STRING_FORMAT_CALLABLE_GETTER[schema_format] - if not format_validator_callable(value): + if not formatstring.validate(value): raise InvalidSchemaValue( "Value of {0} not valid format of {1}".format( value, self.format) diff --git a/openapi_core/validation/request/validators.py b/openapi_core/validation/request/validators.py index 87b638a..08593f6 100644 --- a/openapi_core/validation/request/validators.py +++ b/openapi_core/validation/request/validators.py @@ -11,8 +11,9 @@ from openapi_core.validation.util import get_operation_pattern class RequestValidator(object): - def __init__(self, spec): + def __init__(self, spec, custom_formatters=None): self.spec = spec + self.custom_formatters = custom_formatters def validate(self, request): try: @@ -52,7 +53,7 @@ class RequestValidator(object): continue try: - value = param.unmarshal(raw_value) + value = param.unmarshal(raw_value, self.custom_formatters) except OpenAPIMappingError as exc: errors.append(exc) else: @@ -78,7 +79,7 @@ class RequestValidator(object): errors.append(exc) else: try: - body = media_type.unmarshal(raw_body) + body = media_type.unmarshal(raw_body, self.custom_formatters) except OpenAPIMappingError as exc: errors.append(exc) diff --git a/openapi_core/validation/response/validators.py b/openapi_core/validation/response/validators.py index b926fb8..4f9696b 100644 --- a/openapi_core/validation/response/validators.py +++ b/openapi_core/validation/response/validators.py @@ -6,8 +6,9 @@ from openapi_core.validation.util import get_operation_pattern class ResponseValidator(object): - def __init__(self, spec): + def __init__(self, spec, custom_formatters=None): self.spec = spec + self.custom_formatters = custom_formatters def validate(self, request, response): try: @@ -60,7 +61,7 @@ class ResponseValidator(object): errors.append(exc) else: try: - data = media_type.unmarshal(raw_data) + data = media_type.unmarshal(raw_data, self.custom_formatters) except OpenAPIMappingError as exc: errors.append(exc) diff --git a/requirements.txt b/requirements.txt index f13ec83..a5b8a67 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ openapi-spec-validator six lazy-object-proxy +attrs diff --git a/requirements_2.7.txt b/requirements_2.7.txt index 19ba7bb..97f5124 100644 --- a/requirements_2.7.txt +++ b/requirements_2.7.txt @@ -4,3 +4,4 @@ lazy-object-proxy backports.functools-lru-cache backports.functools-partialmethod enum34 +attrs