Sketch out custom formatters design

This commit is contained in:
Domen Kožar 2018-08-24 15:57:41 +01:00
parent 032aecdd5f
commit 64628d1cc9
No known key found for this signature in database
GPG key ID: C2FFBCAFD2C24246
7 changed files with 52 additions and 38 deletions

View file

@ -32,7 +32,7 @@ class MediaType(object):
deserializer = self.get_dererializer() deserializer = self.get_dererializer()
return deserializer(value) return deserializer(value)
def unmarshal(self, value): def unmarshal(self, value, custom_formatters=None):
if not self.schema: if not self.schema:
return value return value
@ -42,7 +42,7 @@ class MediaType(object):
raise InvalidMediaTypeValue(str(exc)) raise InvalidMediaTypeValue(str(exc))
try: try:
unmarshalled = self.schema.unmarshal(deserialized) unmarshalled = self.schema.unmarshal(deserialized, custom_formatters)
except InvalidSchemaValue as exc: except InvalidSchemaValue as exc:
raise InvalidMediaTypeValue(str(exc)) raise InvalidMediaTypeValue(str(exc))

View file

@ -91,7 +91,7 @@ class Parameter(object):
return location[self.name] return location[self.name]
def unmarshal(self, value): def unmarshal(self, value, custom_formatters=None):
if self.deprecated: if self.deprecated:
warnings.warn( warnings.warn(
"{0} parameter is deprecated".format(self.name), "{0} parameter is deprecated".format(self.name),
@ -112,7 +112,7 @@ class Parameter(object):
raise InvalidParameterValue(str(exc)) raise InvalidParameterValue(str(exc))
try: try:
unmarshalled = self.schema.unmarshal(deserialized) unmarshalled = self.schema.unmarshal(deserialized, custom_formatters)
except InvalidSchemaValue as exc: except InvalidSchemaValue as exc:
raise InvalidParameterValue(str(exc)) raise InvalidParameterValue(str(exc))

View file

@ -1,4 +1,6 @@
"""OpenAPI core schemas models module""" """OpenAPI core schemas models module"""
import attr
import functools
import logging import logging
from collections import defaultdict from collections import defaultdict
from datetime import date, datetime from datetime import date, datetime
@ -23,6 +25,11 @@ from openapi_core.schema.schemas.validators import (
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@attr.s
class StringFormat(object):
format = attr.ib()
validate = attr.ib()
class Schema(object): class Schema(object):
"""Represents an OpenAPI Schema.""" """Represents an OpenAPI Schema."""
@ -33,18 +40,11 @@ class Schema(object):
SchemaType.BOOLEAN: forcebool, SchemaType.BOOLEAN: forcebool,
} }
STRING_FORMAT_CAST_CALLABLE_GETTER = { STRING_FORMAT_CALLABLE_GETTER = {
SchemaFormat.NONE: text_type, SchemaFormat.NONE: StringFormat(text_type, TypeValidator(text_type)),
SchemaFormat.DATE: format_date, SchemaFormat.DATE: StringFormat(format_date, TypeValidator(date, exclude=datetime)),
SchemaFormat.DATETIME: format_datetime, SchemaFormat.DATETIME: StringFormat(format_datetime, TypeValidator(datetime)),
SchemaFormat.BINARY: binary_type, SchemaFormat.BINARY: StringFormat(binary_type, TypeValidator(binary_type)),
}
STRING_FORMAT_VALIDATOR_CALLABLE_GETTER = {
SchemaFormat.NONE: TypeValidator(text_type),
SchemaFormat.DATE: TypeValidator(date, exclude=datetime),
SchemaFormat.DATETIME: TypeValidator(datetime),
SchemaFormat.BINARY: TypeValidator(binary_type),
} }
TYPE_VALIDATOR_CALLABLE_GETTER = { TYPE_VALIDATOR_CALLABLE_GETTER = {
@ -99,6 +99,7 @@ class Schema(object):
self._all_required_properties_cache = None self._all_required_properties_cache = None
self._all_optional_properties_cache = None self._all_optional_properties_cache = None
self.custom_formatters = None
def __getitem__(self, name): def __getitem__(self, name):
return self.properties[name] return self.properties[name]
@ -173,11 +174,13 @@ class Schema(object):
"Failed to cast value of {0} to {1}".format(value, self.type) "Failed to cast value of {0} to {1}".format(value, self.type)
) )
def unmarshal(self, value): def unmarshal(self, value, custom_formatters=None):
"""Unmarshal parameter from the value.""" """Unmarshal parameter from the value."""
if self.deprecated: if self.deprecated:
warnings.warn("The schema is deprecated", DeprecationWarning) warnings.warn("The schema is deprecated", DeprecationWarning)
self.custom_formatters = custom_formatters
casted = self.cast(value) casted = self.cast(value)
if casted is None and not self.required: if casted is None and not self.required:
@ -195,15 +198,18 @@ class Schema(object):
try: try:
schema_format = SchemaFormat(self.format) schema_format = SchemaFormat(self.format)
except ValueError: except ValueError:
# @todo: implement custom format unmarshalling support msg = "Unsupported {0} format unmarshalling".format(self.format)
raise OpenAPISchemaError( if self.custom_formatters is not None:
"Unsupported {0} format unmarshalling".format(self.format) formatstring = self.custom_formatters.get(self.format)
) if formatstring is None:
raise OpenAPISchemaError(msg)
else: else:
formatter = self.STRING_FORMAT_CAST_CALLABLE_GETTER[schema_format] raise OpenAPISchemaError(msg)
else:
formatstring = self.STRING_FORMAT_CALLABLE_GETTER[schema_format]
try: try:
return formatter(value) return formatstring.format(value)
except ValueError: except ValueError:
raise InvalidSchemaValue( raise InvalidSchemaValue(
"Failed to format value of {0} to {1}".format( "Failed to format value of {0} to {1}".format(
@ -231,7 +237,8 @@ class Schema(object):
if self.items is None: if self.items is None:
raise UndefinedItemsSchema("Undefined items' schema") raise UndefinedItemsSchema("Undefined items' schema")
return list(map(self.items.unmarshal, value)) f = functools.partial(self.items.unmarshal, custom_formatters=self.custom_formatters)
return list(map(f, value))
def _unmarshal_object(self, value, model_factory=None): def _unmarshal_object(self, value, model_factory=None):
if not isinstance(value, (dict, )): if not isinstance(value, (dict, )):
@ -286,7 +293,7 @@ class Schema(object):
for prop_name in extra_props: for prop_name in extra_props:
prop_value = value[prop_name] prop_value = value[prop_name]
properties[prop_name] = self.additional_properties.unmarshal( properties[prop_name] = self.additional_properties.unmarshal(
prop_value) prop_value, self.custom_formatters)
for prop_name, prop in iteritems(all_props): for prop_name, prop in iteritems(all_props):
try: try:
@ -298,7 +305,7 @@ class Schema(object):
if not prop.nullable and not prop.default: if not prop.nullable and not prop.default:
continue continue
prop_value = prop.default prop_value = prop.default
properties[prop_name] = prop.unmarshal(prop_value) properties[prop_name] = prop.unmarshal(prop_value, self.custom_formatters)
self._validate_properties(properties, one_of_schema=one_of_schema) self._validate_properties(properties, one_of_schema=one_of_schema)
@ -405,15 +412,18 @@ class Schema(object):
try: try:
schema_format = SchemaFormat(self.format) schema_format = SchemaFormat(self.format)
except ValueError: except ValueError:
# @todo: implement custom format validation support msg = "Unsupported {0} format validation".format(self.format)
raise OpenAPISchemaError( if self.custom_formatters is not None:
"Unsupported {0} format validation".format(self.format) formatstring = self.custom_formatters.get(self.format)
) if formatstring is None:
raise OpenAPISchemaError(msg)
else: else:
format_validator_callable =\ raise OpenAPISchemaError(msg)
self.STRING_FORMAT_VALIDATOR_CALLABLE_GETTER[schema_format] else:
formatstring =\
self.STRING_FORMAT_CALLABLE_GETTER[schema_format]
if not format_validator_callable(value): if not formatstring.validate(value):
raise InvalidSchemaValue( raise InvalidSchemaValue(
"Value of {0} not valid format of {1}".format( "Value of {0} not valid format of {1}".format(
value, self.format) value, self.format)

View file

@ -11,8 +11,9 @@ from openapi_core.validation.util import get_operation_pattern
class RequestValidator(object): class RequestValidator(object):
def __init__(self, spec): def __init__(self, spec, custom_formatters=None):
self.spec = spec self.spec = spec
self.custom_formatters = custom_formatters
def validate(self, request): def validate(self, request):
try: try:
@ -52,7 +53,7 @@ class RequestValidator(object):
continue continue
try: try:
value = param.unmarshal(raw_value) value = param.unmarshal(raw_value, self.custom_formatters)
except OpenAPIMappingError as exc: except OpenAPIMappingError as exc:
errors.append(exc) errors.append(exc)
else: else:
@ -78,7 +79,7 @@ class RequestValidator(object):
errors.append(exc) errors.append(exc)
else: else:
try: try:
body = media_type.unmarshal(raw_body) body = media_type.unmarshal(raw_body, self.custom_formatters)
except OpenAPIMappingError as exc: except OpenAPIMappingError as exc:
errors.append(exc) errors.append(exc)

View file

@ -6,8 +6,9 @@ from openapi_core.validation.util import get_operation_pattern
class ResponseValidator(object): class ResponseValidator(object):
def __init__(self, spec): def __init__(self, spec, custom_formatters=None):
self.spec = spec self.spec = spec
self.custom_formatters = custom_formatters
def validate(self, request, response): def validate(self, request, response):
try: try:
@ -60,7 +61,7 @@ class ResponseValidator(object):
errors.append(exc) errors.append(exc)
else: else:
try: try:
data = media_type.unmarshal(raw_data) data = media_type.unmarshal(raw_data, self.custom_formatters)
except OpenAPIMappingError as exc: except OpenAPIMappingError as exc:
errors.append(exc) errors.append(exc)

View file

@ -1,3 +1,4 @@
openapi-spec-validator openapi-spec-validator
six six
lazy-object-proxy lazy-object-proxy
attrs

View file

@ -4,3 +4,4 @@ lazy-object-proxy
backports.functools-lru-cache backports.functools-lru-cache
backports.functools-partialmethod backports.functools-partialmethod
enum34 enum34
attrs