openapi-core/openapi_core/schema/schemas/models.py
2020-01-17 14:43:18 +00:00

358 lines
13 KiB
Python

"""OpenAPI core schemas models module"""
import attr
import functools
import logging
from collections import defaultdict
import re
import warnings
from six import iteritems
from jsonschema.exceptions import ValidationError
from openapi_core.extensions.models.factories import ModelFactory
from openapi_core.schema.schemas._format import oas30_format_checker
from openapi_core.schema.schemas.enums import SchemaType
from openapi_core.schema.schemas.exceptions import (
CastError, InvalidSchemaValue,
UnmarshalValueError, UnmarshalError,
)
from openapi_core.schema.schemas.util import forcebool
from openapi_core.schema.schemas.validators import OAS30Validator
# Module-level logger; the unmarshalling code in this module reports
# ambiguous or failed oneOf / any-type resolution through it.
log = logging.getLogger(__name__)
@attr.attrs
class Format(object):
    """attrs-generated value object with two fields.

    The fields are named ``unmarshal`` and ``validate``; presumably each
    holds the callable used for that step of handling one schema format
    (NOTE(review): confirm against the code that builds ``Format``).
    """

    # Both fields are mandatory: neither declares a default.
    unmarshal = attr.attrib()
    validate = attr.attrib()
class Schema(object):
    """Represents an OpenAPI Schema.

    An instance stores the keywords of a single schema object and offers
    three operations on values: ``cast`` (convert to the schema type),
    ``validate`` (check with ``OAS30Validator``) and ``unmarshal`` (build
    the final Python value or model).
    """

    # Cast callables for primitive types; ``get_cast_mapping`` copies this
    # mapping and adds the array handler.
    TYPE_CAST_CALLABLE_GETTER = {
        SchemaType.INTEGER: int,
        SchemaType.NUMBER: float,
        SchemaType.BOOLEAN: forcebool,
    }

    # Base of the unmarshal mapping; ``get_unmarshal_mapping`` copies it and
    # layers the per-instance callables on top.
    DEFAULT_UNMARSHAL_CALLABLE_GETTER = {
    }

    def __init__(
            self, schema_type=None, model=None, properties=None, items=None,
            schema_format=None, required=None, default=None, nullable=False,
            enum=None, deprecated=False, all_of=None, one_of=None,
            additional_properties=True, min_items=None, max_items=None,
            min_length=None, max_length=None, pattern=None, unique_items=False,
            minimum=None, maximum=None, multiple_of=None,
            exclusive_minimum=False, exclusive_maximum=False,
            min_properties=None, max_properties=None, _source=None):
        """Store the schema keywords, normalising a few of them.

        * ``schema_type`` is converted with ``SchemaType(...)``.
        * ``properties`` is copied into a dict, ``all_of`` / ``one_of`` into
          lists; missing or empty values become empty containers.
        * Count-like keywords (``min_items``, ``max_length``, ...) are
          converted with ``int`` when given.
        * ``minimum``, ``maximum`` and ``multiple_of`` are converted with
          ``int`` as well, except that a float with a fractional part is
          kept as is so the bound is not truncated.
        * ``pattern`` is compiled with ``re.compile`` when non-empty.
        * ``_source`` is what the ``__dict__`` property returns when truthy.
        """
        self.type = SchemaType(schema_type)
        self.model = model
        self.properties = dict(properties) if properties else {}
        self.items = items
        self.format = schema_format
        self.required = required or []
        self.default = default
        self.nullable = nullable
        self.enum = enum
        self.deprecated = deprecated
        self.all_of = list(all_of) if all_of else []
        self.one_of = list(one_of) if one_of else []
        self.additional_properties = additional_properties
        self.min_items = self._to_int(min_items)
        self.max_items = self._to_int(max_items)
        self.min_length = self._to_int(min_length)
        self.max_length = self._to_int(max_length)
        self.pattern = re.compile(pattern) if pattern else None
        self.unique_items = unique_items
        self.minimum = self._to_number(minimum)
        self.maximum = self._to_number(maximum)
        self.multiple_of = self._to_number(multiple_of)
        self.exclusive_minimum = exclusive_minimum
        self.exclusive_maximum = exclusive_maximum
        self.min_properties = self._to_int(min_properties)
        self.max_properties = self._to_int(max_properties)

        self._all_required_properties_cache = None
        self._all_optional_properties_cache = None

        self._source = _source

    @staticmethod
    def _to_int(value):
        """Return ``int(value)``, or ``None`` when *value* is ``None``."""
        if value is None:
            return None
        return int(value)

    @staticmethod
    def _to_number(value):
        """Coerce a numeric bound without discarding a fractional part.

        A float that is not integral is returned unchanged; every other
        non-``None`` value is converted with ``int``.
        """
        if value is None:
            return None
        if isinstance(value, float) and not value.is_integer():
            return value
        return int(value)

    @property
    def __dict__(self):
        """Return the source given at construction, else ``to_dict()``."""
        return self._source or self.to_dict()

    def to_dict(self):
        """Build a dict for this schema with ``SchemaDictFactory``."""
        # Imported inside the method, as in the original code
        # (presumably to avoid a circular import - TODO confirm).
        from openapi_core.schema.schemas.factories import SchemaDictFactory
        return SchemaDictFactory().create(self)

    def __getitem__(self, name):
        """Return the property schema stored under *name*.

        Only this schema's own ``properties`` are consulted; a missing
        name raises ``KeyError``.
        """
        return self.properties[name]

    def get_all_properties(self):
        """Return own properties merged with those of all allOf subschemas.

        The result is a new dict; entries from subschemas overwrite
        same-named entries collected before them.
        """
        properties = self.properties.copy()

        for subschema in self.all_of:
            properties.update(subschema.get_all_properties())

        return properties

    def get_all_properties_names(self):
        """Return the set of names found by ``get_all_properties``."""
        return set(self.get_all_properties())

    def get_all_required_properties(self):
        """Return the required properties, computing them once.

        The result of ``_get_all_required_properties`` is cached on the
        instance and reused by later calls.
        """
        if self._all_required_properties_cache is None:
            self._all_required_properties_cache =\
                self._get_all_required_properties()

        return self._all_required_properties_cache

    def _get_all_required_properties(self):
        """Return a dict of the known properties whose name is required."""
        all_properties = self.get_all_properties()
        required = self.get_all_required_properties_names()

        return {
            prop_name: val
            for prop_name, val in iteritems(all_properties)
            if prop_name in required
        }

    def get_all_required_properties_names(self):
        """Return the set of required names of this schema and its allOf.

        Names are taken from each subschema's own required names, so a
        name that one subschema requires but another one defines is still
        reported.
        """
        required = set(self.required)

        for subschema in self.all_of:
            required.update(subschema.get_all_required_properties_names())

        return required

    def get_cast_mapping(self):
        """Return a mapping from schema type to cast callable.

        Types without an entry fall back to an identity function.
        """
        mapping = self.TYPE_CAST_CALLABLE_GETTER.copy()
        mapping.update({
            SchemaType.ARRAY: self._cast_collection,
        })

        return defaultdict(lambda: lambda x: x, mapping)

    def cast(self, value):
        """Cast value from string to schema type.

        ``None`` is returned unchanged.

        :raises CastError: when the cast callable rejects the value with
            ``ValueError`` or ``TypeError`` (for example a container
            passed where a number is expected).
        """
        if value is None:
            return value

        cast_callable = self.get_cast_mapping()[self.type]
        try:
            return cast_callable(value)
        except (ValueError, TypeError):
            raise CastError(value, self.type)

    def _cast_collection(self, value):
        """Cast every element of *value* with the ``items`` schema."""
        # NOTE(review): self.items is dereferenced unconditionally; a
        # schema created without items raises AttributeError here.
        return list(map(self.items.cast, value))

    def get_unmarshal_mapping(self, custom_formatters=None, strict=True):
        """Return a mapping from schema type to unmarshal callable.

        Primitive unmarshallers are bound to this schema's format and to
        *strict*; the any/array/object handlers are bound to
        *custom_formatters* and *strict*. Types without an entry fall
        back to an identity function.
        """
        primitive_unmarshallers = self.get_primitive_unmarshallers(
            custom_formatters=custom_formatters)

        def pass_defaults(f):
            return functools.partial(
                f, custom_formatters=custom_formatters, strict=strict)

        mapping = self.DEFAULT_UNMARSHAL_CALLABLE_GETTER.copy()
        mapping.update(
            (schema_type, functools.partial(
                unmarshaller, type_format=self.format, strict=strict))
            for schema_type, unmarshaller in primitive_unmarshallers.items()
        )
        mapping.update({
            SchemaType.ANY: pass_defaults(self._unmarshal_any),
            SchemaType.ARRAY: pass_defaults(self._unmarshal_collection),
            SchemaType.OBJECT: pass_defaults(self._unmarshal_object),
        })

        return defaultdict(lambda: lambda x: x, mapping)

    def get_validator(self, resolver=None):
        """Return an ``OAS30Validator`` built from this schema's dict."""
        return OAS30Validator(
            self.__dict__,
            resolver=resolver, format_checker=oas30_format_checker,
        )

    def validate(self, value, resolver=None):
        """Validate *value* against this schema.

        :raises InvalidSchemaValue: when the validator raises
            ``ValidationError``; the exception receives an iterator over
            the validator's errors.
        """
        validator = self.get_validator(resolver=resolver)
        try:
            return validator.validate(value)
        except ValidationError:
            errors_iter = validator.iter_errors(value)
            raise InvalidSchemaValue(
                value, self.type, schema_errors_iter=errors_iter)

    def unmarshal(self, value, custom_formatters=None, strict=True):
        """Unmarshal parameter from the value.

        ``None`` yields the schema default when the schema is nullable.
        An empty string yields ``None`` for every type but string.

        :raises UnmarshalError: for ``None`` on a non-nullable schema or
            a value outside ``enum``.
        :raises UnmarshalValueError: when the type-specific callable
            raises ``ValueError``.
        """
        if self.deprecated:
            warnings.warn("The schema is deprecated", DeprecationWarning)

        if value is None:
            if not self.nullable:
                raise UnmarshalError(
                    "Null value for non-nullable schema", value, self.type)
            return self.default

        if self.enum and value not in self.enum:
            raise UnmarshalError("Invalid value for enum: {0}".format(value))

        unmarshal_mapping = self.get_unmarshal_mapping(
            custom_formatters=custom_formatters, strict=strict)

        if self.type is not SchemaType.STRING and value == '':
            return None

        unmarshal_callable = unmarshal_mapping[self.type]
        try:
            unmarshalled = unmarshal_callable(value)
        except ValueError as exc:
            raise UnmarshalValueError(value, self.type, exc)

        return unmarshalled

    def get_primitive_unmarshallers(self, **options):
        """Return one unmarshaller instance per primitive schema type.

        *options* are passed to every unmarshaller constructor.
        """
        from openapi_core.schema.schemas.unmarshallers import (
            StringUnmarshaller, BooleanUnmarshaller, IntegerUnmarshaller,
            NumberUnmarshaller,
        )

        unmarshallers_classes = {
            SchemaType.STRING: StringUnmarshaller,
            SchemaType.BOOLEAN: BooleanUnmarshaller,
            SchemaType.INTEGER: IntegerUnmarshaller,
            SchemaType.NUMBER: NumberUnmarshaller,
        }

        return {
            t: klass(**options)
            for t, klass in unmarshallers_classes.items()
        }

    def _unmarshal_any(self, value, custom_formatters=None, strict=True):
        """Unmarshal a value for a schema without a concrete type.

        With ``oneOf`` subschemas, the first subschema that unmarshals
        the value wins; further matches are only logged. Without them,
        each type is tried in a fixed order and the first success is
        returned; if none succeeds the value is returned unchanged.
        """
        if self.one_of:
            result = None
            for subschema in self.one_of:
                try:
                    unmarshalled = subschema.unmarshal(
                        value, custom_formatters)
                except UnmarshalError:
                    continue

                if result is not None:
                    log.warning("multiple valid oneOf schemas found")
                    continue
                result = unmarshalled

            if result is None:
                log.warning("valid oneOf schema not found")
            return result

        types_resolve_order = [
            SchemaType.OBJECT, SchemaType.ARRAY, SchemaType.BOOLEAN,
            SchemaType.INTEGER, SchemaType.NUMBER, SchemaType.STRING,
        ]
        # Forward the custom formatters so untyped schemas honour them.
        # ``strict`` stays at its default: a lenient primitive unmarshaller
        # could claim a value meant for a later type in the order above.
        # NOTE(review): confirm that strict resolution is the intent.
        unmarshal_mapping = self.get_unmarshal_mapping(
            custom_formatters=custom_formatters)
        for schema_type in types_resolve_order:
            unmarshal_callable = unmarshal_mapping[schema_type]
            try:
                return unmarshal_callable(value)
            except (UnmarshalError, ValueError):
                continue

        log.warning("failed to unmarshal any type")
        return value

    def _unmarshal_collection(
            self, value, custom_formatters=None, strict=True):
        """Unmarshal every element of a list or tuple with ``items``.

        :raises ValueError: when *value* is neither a list nor a tuple.
        """
        if not isinstance(value, (list, tuple)):
            raise ValueError(
                "Invalid value for collection: {0}".format(value))

        # NOTE(review): self.items is dereferenced unconditionally; a
        # schema created without items raises AttributeError here.
        f = functools.partial(
            self.items.unmarshal,
            custom_formatters=custom_formatters, strict=strict,
        )
        return list(map(f, value))

    def _unmarshal_object(self, value, model_factory=None,
                          custom_formatters=None, strict=True):
        """Unmarshal a dict into a model created by *model_factory*.

        With ``oneOf`` subschemas, the properties of the first subschema
        that unmarshals without ``UnmarshalError`` / ``ValueError`` are
        used; further matches are only logged. When none matches, the
        factory is called with ``None`` as properties.

        ``strict`` is accepted but not forwarded to the properties.

        :raises ValueError: when *value* is not a dict.
        """
        if not isinstance(value, (dict, )):
            raise ValueError("Invalid value for object: {0}".format(value))

        model_factory = model_factory or ModelFactory()

        if self.one_of:
            properties = None
            for one_of_schema in self.one_of:
                try:
                    unmarshalled = self._unmarshal_properties(
                        value, one_of_schema,
                        custom_formatters=custom_formatters,
                    )
                except (UnmarshalError, ValueError):
                    continue

                if properties is not None:
                    log.warning("multiple valid oneOf schemas found")
                    continue
                properties = unmarshalled

            if properties is None:
                log.warning("valid oneOf schema not found")
        else:
            properties = self._unmarshal_properties(
                value, custom_formatters=custom_formatters)

        return model_factory.create(properties, name=self.model)

    def _unmarshal_properties(self, value, one_of_schema=None,
                              custom_formatters=None, strict=True):
        """Unmarshal the entries of *value* into a plain dict.

        Known properties come from this schema (including allOf) plus
        those of *one_of_schema* when given. Keys of *value* that are not
        known properties are:

        * dropped when ``additional_properties`` is ``True``;
        * rejected when it is ``False``;
        * otherwise unmarshalled with ``additional_properties``.

        A known property absent from *value* is skipped unless it is
        nullable or has a default (any value other than ``None``), in
        which case the default is unmarshalled in its place.

        ``strict`` is accepted but not forwarded to the properties.

        :raises UnmarshalError: when ``additional_properties`` is
            ``False`` and *value* holds keys that are not known
            properties.
        """
        all_props = self.get_all_properties()
        if one_of_schema is not None:
            all_props.update(one_of_schema.get_all_properties())

        extra_props = set(value) - set(all_props)

        properties = {}
        additional = self.additional_properties
        if additional is False:
            # A bool has no ``unmarshal``; report the offending keys.
            if extra_props:
                raise UnmarshalError(
                    "Additional properties are not allowed: {0}".format(
                        ", ".join(sorted(map(str, extra_props)))))
        elif additional is not True:
            for prop_name in extra_props:
                properties[prop_name] = additional.unmarshal(
                    value[prop_name], custom_formatters=custom_formatters)

        for prop_name, prop in iteritems(all_props):
            try:
                prop_value = value[prop_name]
            except KeyError:
                # Compare with None so that falsy defaults such as 0,
                # False or "" are still applied.
                if not prop.nullable and prop.default is None:
                    continue
                prop_value = prop.default

            properties[prop_name] = prop.unmarshal(
                prop_value, custom_formatters=custom_formatters)

        return properties