From 89a53f6edcd59cab69fcf81dfb36ab003b481b5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Domen=20Ko=C5=BEar?= Date: Wed, 5 Sep 2018 12:39:10 +0100 Subject: [PATCH] review feedback --- openapi_core/schema/media_types/models.py | 4 +- openapi_core/schema/parameters/models.py | 4 +- openapi_core/schema/schemas/models.py | 105 ++++++++++++---------- setup.py | 1 + 4 files changed, 65 insertions(+), 49 deletions(-) diff --git a/openapi_core/schema/media_types/models.py b/openapi_core/schema/media_types/models.py index 760587f..e2b0d2f 100644 --- a/openapi_core/schema/media_types/models.py +++ b/openapi_core/schema/media_types/models.py @@ -42,11 +42,11 @@ class MediaType(object): raise InvalidMediaTypeValue(str(exc)) try: - unmarshalled = self.schema.unmarshal(deserialized, custom_formatters) + unmarshalled = self.schema.unmarshal(deserialized, custom_formatters=custom_formatters) except InvalidSchemaValue as exc: raise InvalidMediaTypeValue(str(exc)) try: - return self.schema.validate(unmarshalled) + return self.schema.validate(unmarshalled, custom_formatters=custom_formatters) except InvalidSchemaValue as exc: raise InvalidMediaTypeValue(str(exc)) diff --git a/openapi_core/schema/parameters/models.py b/openapi_core/schema/parameters/models.py index c515128..e484ce1 100644 --- a/openapi_core/schema/parameters/models.py +++ b/openapi_core/schema/parameters/models.py @@ -112,11 +112,11 @@ class Parameter(object): raise InvalidParameterValue(str(exc)) try: - unmarshalled = self.schema.unmarshal(deserialized, custom_formatters) + unmarshalled = self.schema.unmarshal(deserialized, custom_formatters=custom_formatters) except InvalidSchemaValue as exc: raise InvalidParameterValue(str(exc)) try: - return self.schema.validate(unmarshalled) + return self.schema.validate(unmarshalled, custom_formatters=custom_formatters) except InvalidSchemaValue as exc: raise InvalidParameterValue(str(exc)) diff --git a/openapi_core/schema/schemas/models.py b/openapi_core/schema/schemas/models.py index 31139c7..d5ff5b7 100644 --- a/openapi_core/schema/schemas/models.py +++ b/openapi_core/schema/schemas/models.py @@ -25,9 +25,10 @@ from openapi_core.schema.schemas.validators import ( log = logging.getLogger(__name__) + @attr.s -class StringFormat(object): - format = attr.ib() +class Format(object): + unmarshal = attr.ib() validate = attr.ib() @@ -41,10 +42,10 @@ class Schema(object): } STRING_FORMAT_CALLABLE_GETTER = { - SchemaFormat.NONE: StringFormat(text_type, TypeValidator(text_type)), - SchemaFormat.DATE: StringFormat(format_date, TypeValidator(date, exclude=datetime)), - SchemaFormat.DATETIME: StringFormat(format_datetime, TypeValidator(datetime)), - SchemaFormat.BINARY: StringFormat(binary_type, TypeValidator(binary_type)), + SchemaFormat.NONE: Format(text_type, TypeValidator(text_type)), + SchemaFormat.DATE: Format(format_date, TypeValidator(date, exclude=datetime)), + SchemaFormat.DATETIME: Format(format_datetime, TypeValidator(datetime)), + SchemaFormat.BINARY: Format(binary_type, TypeValidator(binary_type)), } TYPE_VALIDATOR_CALLABLE_GETTER = { @@ -99,7 +100,6 @@ class Schema(object): self._all_required_properties_cache = None self._all_optional_properties_cache = None - self.custom_formatters = None def __getitem__(self, name): return self.properties[name] @@ -143,25 +143,27 @@ class Schema(object): return set(required) - def get_cast_mapping(self): + def get_cast_mapping(self, custom_formatters=None): + pass_defaults = lambda f: functools.partial( + f, custom_formatters=custom_formatters) mapping = self.DEFAULT_CAST_CALLABLE_GETTER.copy() mapping.update({ - SchemaType.STRING: self._unmarshal_string, - SchemaType.ANY: self._unmarshal_any, - SchemaType.ARRAY: self._unmarshal_collection, - SchemaType.OBJECT: self._unmarshal_object, + SchemaType.STRING: pass_defaults(self._unmarshal_string), + SchemaType.ANY: pass_defaults(self._unmarshal_any), + SchemaType.ARRAY: pass_defaults(self._unmarshal_collection), + SchemaType.OBJECT: pass_defaults(self._unmarshal_object), }) return defaultdict(lambda: lambda x: x, mapping) - def cast(self, value): + def cast(self, value, custom_formatters=None): """Cast value to schema type""" if value is None: if not self.nullable: raise InvalidSchemaValue("Null value for non-nullable schema") return self.default - cast_mapping = self.get_cast_mapping() + cast_mapping = self.get_cast_mapping(custom_formatters=custom_formatters) if self.type is not SchemaType.STRING and value == '': return None @@ -179,9 +181,7 @@ class Schema(object): if self.deprecated: warnings.warn("The schema is deprecated", DeprecationWarning) - self.custom_formatters = custom_formatters - - casted = self.cast(value) + casted = self.cast(value, custom_formatters=custom_formatters) if casted is None and not self.required: return None @@ -194,13 +194,13 @@ class Schema(object): return casted - def _unmarshal_string(self, value): + def _unmarshal_string(self, value, custom_formatters=None): try: schema_format = SchemaFormat(self.format) except ValueError: msg = "Unsupported {0} format unmarshalling".format(self.format) - if self.custom_formatters is not None: - formatstring = self.custom_formatters.get(self.format) + if custom_formatters is not None: + formatstring = custom_formatters.get(self.format) if formatstring is None: raise OpenAPISchemaError(msg) else: @@ -209,14 +209,14 @@ class Schema(object): formatstring = self.STRING_FORMAT_CALLABLE_GETTER[schema_format] try: - return formatstring.format(value) + return formatstring.unmarshal(value) except ValueError: raise InvalidSchemaValue( "Failed to format value of {0} to {1}".format( value, self.format) ) - def _unmarshal_any(self, value): + def _unmarshal_any(self, value, custom_formatters=None): types_resolve_order = [ SchemaType.OBJECT, SchemaType.ARRAY, SchemaType.BOOLEAN, SchemaType.INTEGER, SchemaType.NUMBER, SchemaType.STRING, @@ -233,14 +233,16 @@ class Schema(object): raise NoValidSchema( "No valid schema found for value {0}".format(value)) - def _unmarshal_collection(self, value): + def _unmarshal_collection(self, value, custom_formatters=None): if self.items is None: raise UndefinedItemsSchema("Undefined items' schema") - f = functools.partial(self.items.unmarshal, custom_formatters=self.custom_formatters) + f = functools.partial(self.items.unmarshal, + custom_formatters=custom_formatters) return list(map(f, value)) - def _unmarshal_object(self, value, model_factory=None): + def _unmarshal_object(self, value, model_factory=None, + custom_formatters=None): if not isinstance(value, (dict, )): raise InvalidSchemaValue( "Value of {0} not a dict".format(value)) @@ -252,7 +254,7 @@ class Schema(object): for one_of_schema in self.one_of: try: found_props = self._unmarshal_properties( - value, one_of_schema) + value, one_of_schema, custom_formatters=custom_formatters) except OpenAPISchemaError: pass else: @@ -267,11 +269,13 @@ class Schema(object): "Exactly one valid schema should be valid, None found.") else: - properties = self._unmarshal_properties(value) + properties = self._unmarshal_properties( + value, custom_formatters=custom_formatters) return model_factory.create(properties, name=self.model) - def _unmarshal_properties(self, value, one_of_schema=None): + def _unmarshal_properties(self, value, one_of_schema=None, + custom_formatters=None): all_props = self.get_all_properties() all_props_names = self.get_all_properties_names() all_req_props_names = self.get_all_required_properties_names() @@ -293,7 +297,7 @@ class Schema(object): for prop_name in extra_props: prop_value = value[prop_name] properties[prop_name] = self.additional_properties.unmarshal( - prop_value, self.custom_formatters) + prop_value, custom_formatters=custom_formatters) for prop_name, prop in iteritems(all_props): try: @@ -305,9 +309,11 @@ class Schema(object): if not prop.nullable and not prop.default: continue prop_value = prop.default - properties[prop_name] = prop.unmarshal(prop_value, self.custom_formatters) + properties[prop_name] = prop.unmarshal( + prop_value, custom_formatters=custom_formatters) - self._validate_properties(properties, one_of_schema=one_of_schema) + self._validate_properties(properties, one_of_schema=one_of_schema, + custom_formatters=custom_formatters) return properties @@ -320,9 +326,12 @@ class Schema(object): SchemaType.NUMBER: self._validate_number, } - return defaultdict(lambda: lambda x: x, mapping) + def default(x, **kw): + return x - def validate(self, value): + return defaultdict(lambda: default, mapping) + + def validate(self, value, custom_formatters=None): if value is None: if not self.nullable: raise InvalidSchemaValue("Null value for non-nullable schema") @@ -340,11 +349,11 @@ class Schema(object): # structure validation validator_mapping = self.get_validator_mapping() validator_callable = validator_mapping[self.type] - validator_callable(value) + validator_callable(value, custom_formatters=custom_formatters) return value - def _validate_collection(self, value): + def _validate_collection(self, value, custom_formatters=None): if self.items is None: raise OpenAPISchemaError("Schema for collection not defined") @@ -375,7 +384,9 @@ class Schema(object): if self.unique_items and len(set(value)) != len(value): raise InvalidSchemaValue("Value may not contain duplicate items") - return list(map(self.items.validate, value)) + f = functools.partial(self.items.validate, + custom_formatters=custom_formatters) + return list(map(f, value)) def _validate_number(self, value): if self.minimum is not None: @@ -408,13 +419,13 @@ class Schema(object): value, self.multiple_of) ) - def _validate_string(self, value): + def _validate_string(self, value, custom_formatters=None): try: schema_format = SchemaFormat(self.format) except ValueError: msg = "Unsupported {0} format validation".format(self.format) - if self.custom_formatters is not None: - formatstring = self.custom_formatters.get(self.format) + if custom_formatters is not None: + formatstring = custom_formatters.get(self.format) if formatstring is None: raise OpenAPISchemaError(msg) else: @@ -459,14 +470,16 @@ class Schema(object): return True - def _validate_object(self, value): + def _validate_object(self, value, custom_formatters=None): properties = value.__dict__ if self.one_of: valid_one_of_schema = None for one_of_schema in self.one_of: try: - self._validate_properties(properties, one_of_schema) + self._validate_properties( + properties, one_of_schema, + custom_formatters=custom_formatters) except OpenAPISchemaError: pass else: @@ -481,7 +494,8 @@ class Schema(object): "Exactly one valid schema should be valid, None found.") else: - self._validate_properties(properties) + self._validate_properties(properties, + custom_formatters=custom_formatters) if self.min_properties is not None: if self.min_properties < 0: @@ -512,7 +526,8 @@ class Schema(object): return True - def _validate_properties(self, value, one_of_schema=None): + def _validate_properties(self, value, one_of_schema=None, + custom_formatters=None): all_props = self.get_all_properties() all_props_names = self.get_all_properties_names() all_req_props_names = self.get_all_required_properties_names() @@ -533,7 +548,7 @@ class Schema(object): for prop_name in extra_props: prop_value = value[prop_name] self.additional_properties.validate( - prop_value) + prop_value, custom_formatters=custom_formatters) for prop_name, prop in iteritems(all_props): try: @@ -545,6 +560,6 @@ class Schema(object): if not prop.nullable and not prop.default: continue prop_value = prop.default - prop.validate(prop_value) + prop.validate(prop_value, custom_formatters=custom_formatters) return True diff --git a/setup.py b/setup.py index 483c4f9..42ea7dc 100644 --- a/setup.py +++ b/setup.py @@ -44,6 +44,7 @@ class PyTest(TestCommand): '--cov', 'openapi_core', '--cov-report', 'term-missing', '--cov-report', 'xml:reports/coverage.xml', + 'tests', ] self.test_suite = True