review feedback

This commit is contained in:
Domen Kožar 2018-09-05 12:39:10 +01:00
parent 64628d1cc9
commit 89a53f6edc
No known key found for this signature in database
GPG key ID: C2FFBCAFD2C24246
4 changed files with 65 additions and 49 deletions

View file

@ -42,11 +42,11 @@ class MediaType(object):
raise InvalidMediaTypeValue(str(exc)) raise InvalidMediaTypeValue(str(exc))
try: try:
unmarshalled = self.schema.unmarshal(deserialized, custom_formatters) unmarshalled = self.schema.unmarshal(deserialized, custom_formatters=custom_formatters)
except InvalidSchemaValue as exc: except InvalidSchemaValue as exc:
raise InvalidMediaTypeValue(str(exc)) raise InvalidMediaTypeValue(str(exc))
try: try:
return self.schema.validate(unmarshalled) return self.schema.validate(unmarshalled, custom_formatters=custom_formatters)
except InvalidSchemaValue as exc: except InvalidSchemaValue as exc:
raise InvalidMediaTypeValue(str(exc)) raise InvalidMediaTypeValue(str(exc))

View file

@ -112,11 +112,11 @@ class Parameter(object):
raise InvalidParameterValue(str(exc)) raise InvalidParameterValue(str(exc))
try: try:
unmarshalled = self.schema.unmarshal(deserialized, custom_formatters) unmarshalled = self.schema.unmarshal(deserialized, custom_formatters=custom_formatters)
except InvalidSchemaValue as exc: except InvalidSchemaValue as exc:
raise InvalidParameterValue(str(exc)) raise InvalidParameterValue(str(exc))
try: try:
return self.schema.validate(unmarshalled) return self.schema.validate(unmarshalled, custom_formatters=custom_formatters)
except InvalidSchemaValue as exc: except InvalidSchemaValue as exc:
raise InvalidParameterValue(str(exc)) raise InvalidParameterValue(str(exc))

View file

@ -25,9 +25,10 @@ from openapi_core.schema.schemas.validators import (
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@attr.s @attr.s
class StringFormat(object): class Format(object):
format = attr.ib() unmarshal = attr.ib()
validate = attr.ib() validate = attr.ib()
@ -41,10 +42,10 @@ class Schema(object):
} }
STRING_FORMAT_CALLABLE_GETTER = { STRING_FORMAT_CALLABLE_GETTER = {
SchemaFormat.NONE: StringFormat(text_type, TypeValidator(text_type)), SchemaFormat.NONE: Format(text_type, TypeValidator(text_type)),
SchemaFormat.DATE: StringFormat(format_date, TypeValidator(date, exclude=datetime)), SchemaFormat.DATE: Format(format_date, TypeValidator(date, exclude=datetime)),
SchemaFormat.DATETIME: StringFormat(format_datetime, TypeValidator(datetime)), SchemaFormat.DATETIME: Format(format_datetime, TypeValidator(datetime)),
SchemaFormat.BINARY: StringFormat(binary_type, TypeValidator(binary_type)), SchemaFormat.BINARY: Format(binary_type, TypeValidator(binary_type)),
} }
TYPE_VALIDATOR_CALLABLE_GETTER = { TYPE_VALIDATOR_CALLABLE_GETTER = {
@ -99,7 +100,6 @@ class Schema(object):
self._all_required_properties_cache = None self._all_required_properties_cache = None
self._all_optional_properties_cache = None self._all_optional_properties_cache = None
self.custom_formatters = None
def __getitem__(self, name): def __getitem__(self, name):
return self.properties[name] return self.properties[name]
@ -143,25 +143,27 @@ class Schema(object):
return set(required) return set(required)
def get_cast_mapping(self): def get_cast_mapping(self, custom_formatters=None):
pass_defaults = lambda f: functools.partial(
f, custom_formatters=custom_formatters)
mapping = self.DEFAULT_CAST_CALLABLE_GETTER.copy() mapping = self.DEFAULT_CAST_CALLABLE_GETTER.copy()
mapping.update({ mapping.update({
SchemaType.STRING: self._unmarshal_string, SchemaType.STRING: pass_defaults(self._unmarshal_string),
SchemaType.ANY: self._unmarshal_any, SchemaType.ANY: pass_defaults(self._unmarshal_any),
SchemaType.ARRAY: self._unmarshal_collection, SchemaType.ARRAY: pass_defaults(self._unmarshal_collection),
SchemaType.OBJECT: self._unmarshal_object, SchemaType.OBJECT: pass_defaults(self._unmarshal_object),
}) })
return defaultdict(lambda: lambda x: x, mapping) return defaultdict(lambda: lambda x: x, mapping)
def cast(self, value): def cast(self, value, custom_formatters=None):
"""Cast value to schema type""" """Cast value to schema type"""
if value is None: if value is None:
if not self.nullable: if not self.nullable:
raise InvalidSchemaValue("Null value for non-nullable schema") raise InvalidSchemaValue("Null value for non-nullable schema")
return self.default return self.default
cast_mapping = self.get_cast_mapping() cast_mapping = self.get_cast_mapping(custom_formatters=custom_formatters)
if self.type is not SchemaType.STRING and value == '': if self.type is not SchemaType.STRING and value == '':
return None return None
@ -179,9 +181,7 @@ class Schema(object):
if self.deprecated: if self.deprecated:
warnings.warn("The schema is deprecated", DeprecationWarning) warnings.warn("The schema is deprecated", DeprecationWarning)
self.custom_formatters = custom_formatters casted = self.cast(value, custom_formatters=custom_formatters)
casted = self.cast(value)
if casted is None and not self.required: if casted is None and not self.required:
return None return None
@ -194,13 +194,13 @@ class Schema(object):
return casted return casted
def _unmarshal_string(self, value): def _unmarshal_string(self, value, custom_formatters=None):
try: try:
schema_format = SchemaFormat(self.format) schema_format = SchemaFormat(self.format)
except ValueError: except ValueError:
msg = "Unsupported {0} format unmarshalling".format(self.format) msg = "Unsupported {0} format unmarshalling".format(self.format)
if self.custom_formatters is not None: if custom_formatters is not None:
formatstring = self.custom_formatters.get(self.format) formatstring = custom_formatters.get(self.format)
if formatstring is None: if formatstring is None:
raise OpenAPISchemaError(msg) raise OpenAPISchemaError(msg)
else: else:
@ -209,14 +209,14 @@ class Schema(object):
formatstring = self.STRING_FORMAT_CALLABLE_GETTER[schema_format] formatstring = self.STRING_FORMAT_CALLABLE_GETTER[schema_format]
try: try:
return formatstring.format(value) return formatstring.unmarshal(value)
except ValueError: except ValueError:
raise InvalidSchemaValue( raise InvalidSchemaValue(
"Failed to format value of {0} to {1}".format( "Failed to format value of {0} to {1}".format(
value, self.format) value, self.format)
) )
def _unmarshal_any(self, value): def _unmarshal_any(self, value, custom_formatters=None):
types_resolve_order = [ types_resolve_order = [
SchemaType.OBJECT, SchemaType.ARRAY, SchemaType.BOOLEAN, SchemaType.OBJECT, SchemaType.ARRAY, SchemaType.BOOLEAN,
SchemaType.INTEGER, SchemaType.NUMBER, SchemaType.STRING, SchemaType.INTEGER, SchemaType.NUMBER, SchemaType.STRING,
@ -233,14 +233,16 @@ class Schema(object):
raise NoValidSchema( raise NoValidSchema(
"No valid schema found for value {0}".format(value)) "No valid schema found for value {0}".format(value))
def _unmarshal_collection(self, value): def _unmarshal_collection(self, value, custom_formatters=None):
if self.items is None: if self.items is None:
raise UndefinedItemsSchema("Undefined items' schema") raise UndefinedItemsSchema("Undefined items' schema")
f = functools.partial(self.items.unmarshal, custom_formatters=self.custom_formatters) f = functools.partial(self.items.unmarshal,
custom_formatters=custom_formatters)
return list(map(f, value)) return list(map(f, value))
def _unmarshal_object(self, value, model_factory=None): def _unmarshal_object(self, value, model_factory=None,
custom_formatters=None):
if not isinstance(value, (dict, )): if not isinstance(value, (dict, )):
raise InvalidSchemaValue( raise InvalidSchemaValue(
"Value of {0} not a dict".format(value)) "Value of {0} not a dict".format(value))
@ -252,7 +254,7 @@ class Schema(object):
for one_of_schema in self.one_of: for one_of_schema in self.one_of:
try: try:
found_props = self._unmarshal_properties( found_props = self._unmarshal_properties(
value, one_of_schema) value, one_of_schema, custom_formatters=custom_formatters)
except OpenAPISchemaError: except OpenAPISchemaError:
pass pass
else: else:
@ -267,11 +269,13 @@ class Schema(object):
"Exactly one valid schema should be valid, None found.") "Exactly one valid schema should be valid, None found.")
else: else:
properties = self._unmarshal_properties(value) properties = self._unmarshal_properties(
value, custom_formatters=custom_formatters)
return model_factory.create(properties, name=self.model) return model_factory.create(properties, name=self.model)
def _unmarshal_properties(self, value, one_of_schema=None): def _unmarshal_properties(self, value, one_of_schema=None,
custom_formatters=None):
all_props = self.get_all_properties() all_props = self.get_all_properties()
all_props_names = self.get_all_properties_names() all_props_names = self.get_all_properties_names()
all_req_props_names = self.get_all_required_properties_names() all_req_props_names = self.get_all_required_properties_names()
@ -293,7 +297,7 @@ class Schema(object):
for prop_name in extra_props: for prop_name in extra_props:
prop_value = value[prop_name] prop_value = value[prop_name]
properties[prop_name] = self.additional_properties.unmarshal( properties[prop_name] = self.additional_properties.unmarshal(
prop_value, self.custom_formatters) prop_value, custom_formatters=custom_formatters)
for prop_name, prop in iteritems(all_props): for prop_name, prop in iteritems(all_props):
try: try:
@ -305,9 +309,11 @@ class Schema(object):
if not prop.nullable and not prop.default: if not prop.nullable and not prop.default:
continue continue
prop_value = prop.default prop_value = prop.default
properties[prop_name] = prop.unmarshal(prop_value, self.custom_formatters) properties[prop_name] = prop.unmarshal(
prop_value, custom_formatters=custom_formatters)
self._validate_properties(properties, one_of_schema=one_of_schema) self._validate_properties(properties, one_of_schema=one_of_schema,
custom_formatters=custom_formatters)
return properties return properties
@ -320,9 +326,12 @@ class Schema(object):
SchemaType.NUMBER: self._validate_number, SchemaType.NUMBER: self._validate_number,
} }
return defaultdict(lambda: lambda x: x, mapping) def default(x, **kw):
return x
def validate(self, value): return defaultdict(lambda: default, mapping)
def validate(self, value, custom_formatters=None):
if value is None: if value is None:
if not self.nullable: if not self.nullable:
raise InvalidSchemaValue("Null value for non-nullable schema") raise InvalidSchemaValue("Null value for non-nullable schema")
@ -340,11 +349,11 @@ class Schema(object):
# structure validation # structure validation
validator_mapping = self.get_validator_mapping() validator_mapping = self.get_validator_mapping()
validator_callable = validator_mapping[self.type] validator_callable = validator_mapping[self.type]
validator_callable(value) validator_callable(value, custom_formatters=custom_formatters)
return value return value
def _validate_collection(self, value): def _validate_collection(self, value, custom_formatters=None):
if self.items is None: if self.items is None:
raise OpenAPISchemaError("Schema for collection not defined") raise OpenAPISchemaError("Schema for collection not defined")
@ -375,7 +384,9 @@ class Schema(object):
if self.unique_items and len(set(value)) != len(value): if self.unique_items and len(set(value)) != len(value):
raise InvalidSchemaValue("Value may not contain duplicate items") raise InvalidSchemaValue("Value may not contain duplicate items")
return list(map(self.items.validate, value)) f = functools.partial(self.items.validate,
custom_formatters=custom_formatters)
return list(map(f, value))
def _validate_number(self, value): def _validate_number(self, value):
if self.minimum is not None: if self.minimum is not None:
@ -408,13 +419,13 @@ class Schema(object):
value, self.multiple_of) value, self.multiple_of)
) )
def _validate_string(self, value): def _validate_string(self, value, custom_formatters=None):
try: try:
schema_format = SchemaFormat(self.format) schema_format = SchemaFormat(self.format)
except ValueError: except ValueError:
msg = "Unsupported {0} format validation".format(self.format) msg = "Unsupported {0} format validation".format(self.format)
if self.custom_formatters is not None: if custom_formatters is not None:
formatstring = self.custom_formatters.get(self.format) formatstring = custom_formatters.get(self.format)
if formatstring is None: if formatstring is None:
raise OpenAPISchemaError(msg) raise OpenAPISchemaError(msg)
else: else:
@ -459,14 +470,16 @@ class Schema(object):
return True return True
def _validate_object(self, value): def _validate_object(self, value, custom_formatters=None):
properties = value.__dict__ properties = value.__dict__
if self.one_of: if self.one_of:
valid_one_of_schema = None valid_one_of_schema = None
for one_of_schema in self.one_of: for one_of_schema in self.one_of:
try: try:
self._validate_properties(properties, one_of_schema) self._validate_properties(
properties, one_of_schema,
custom_formatters=custom_formatters)
except OpenAPISchemaError: except OpenAPISchemaError:
pass pass
else: else:
@ -481,7 +494,8 @@ class Schema(object):
"Exactly one valid schema should be valid, None found.") "Exactly one valid schema should be valid, None found.")
else: else:
self._validate_properties(properties) self._validate_properties(properties,
custom_formatters=custom_formatters)
if self.min_properties is not None: if self.min_properties is not None:
if self.min_properties < 0: if self.min_properties < 0:
@ -512,7 +526,8 @@ class Schema(object):
return True return True
def _validate_properties(self, value, one_of_schema=None): def _validate_properties(self, value, one_of_schema=None,
custom_formatters=None):
all_props = self.get_all_properties() all_props = self.get_all_properties()
all_props_names = self.get_all_properties_names() all_props_names = self.get_all_properties_names()
all_req_props_names = self.get_all_required_properties_names() all_req_props_names = self.get_all_required_properties_names()
@ -533,7 +548,7 @@ class Schema(object):
for prop_name in extra_props: for prop_name in extra_props:
prop_value = value[prop_name] prop_value = value[prop_name]
self.additional_properties.validate( self.additional_properties.validate(
prop_value) prop_value, custom_formatters=custom_formatters)
for prop_name, prop in iteritems(all_props): for prop_name, prop in iteritems(all_props):
try: try:
@ -545,6 +560,6 @@ class Schema(object):
if not prop.nullable and not prop.default: if not prop.nullable and not prop.default:
continue continue
prop_value = prop.default prop_value = prop.default
prop.validate(prop_value) prop.validate(prop_value, custom_formatters=custom_formatters)
return True return True

View file

@ -44,6 +44,7 @@ class PyTest(TestCommand):
'--cov', 'openapi_core', '--cov', 'openapi_core',
'--cov-report', 'term-missing', '--cov-report', 'term-missing',
'--cov-report', 'xml:reports/coverage.xml', '--cov-report', 'xml:reports/coverage.xml',
'tests',
] ]
self.test_suite = True self.test_suite = True