Spec path

This commit is contained in:
p1c2u 2021-04-23 12:36:27 +01:00
parent b79c49420e
commit 35f8e28157
38 changed files with 1530 additions and 723 deletions

View file

@ -14,7 +14,7 @@ class PrimitiveCaster(object):
try: try:
return self.caster_callable(value) return self.caster_callable(value)
except (ValueError, TypeError): except (ValueError, TypeError):
raise CastError(value, self.schema.type.value) raise CastError(value, self.schema['type'])
class DummyCaster(object): class DummyCaster(object):
@ -31,7 +31,7 @@ class ArrayCaster(object):
@property @property
def items_caster(self): def items_caster(self):
return self.casters_factory.create(self.schema.items) return self.casters_factory.create(self.schema / 'items')
def __call__(self, value): def __call__(self, value):
if value in (None, NoValue): if value in (None, NoValue):

View file

@ -1,5 +1,3 @@
from openapi_core.schema.schemas.enums import SchemaType
from openapi_core.casting.schemas.casters import ( from openapi_core.casting.schemas.casters import (
PrimitiveCaster, DummyCaster, ArrayCaster PrimitiveCaster, DummyCaster, ArrayCaster
) )
@ -9,23 +7,24 @@ from openapi_core.casting.schemas.util import forcebool
class SchemaCastersFactory(object): class SchemaCastersFactory(object):
DUMMY_CASTERS = [ DUMMY_CASTERS = [
SchemaType.STRING, SchemaType.OBJECT, SchemaType.ANY, 'string', 'object', 'any',
] ]
PRIMITIVE_CASTERS = { PRIMITIVE_CASTERS = {
SchemaType.INTEGER: int, 'integer': int,
SchemaType.NUMBER: float, 'number': float,
SchemaType.BOOLEAN: forcebool, 'boolean': forcebool,
} }
COMPLEX_CASTERS = { COMPLEX_CASTERS = {
SchemaType.ARRAY: ArrayCaster, 'array': ArrayCaster,
} }
def create(self, schema): def create(self, schema):
if schema.type in self.DUMMY_CASTERS: schema_type = schema.getkey('type', 'any')
if schema_type in self.DUMMY_CASTERS:
return DummyCaster() return DummyCaster()
elif schema.type in self.PRIMITIVE_CASTERS: elif schema_type in self.PRIMITIVE_CASTERS:
caster_callable = self.PRIMITIVE_CASTERS[schema.type] caster_callable = self.PRIMITIVE_CASTERS[schema_type]
return PrimitiveCaster(schema, caster_callable) return PrimitiveCaster(schema, caster_callable)
elif schema.type in self.COMPLEX_CASTERS: elif schema_type in self.COMPLEX_CASTERS:
caster_class = self.COMPLEX_CASTERS[schema.type] caster_class = self.COMPLEX_CASTERS[schema_type]
return caster_class(schema, self) return caster_class(schema, self)

View file

@ -20,11 +20,11 @@ class MediaTypeDeserializersFactory(object):
custom_deserializers = {} custom_deserializers = {}
self.custom_deserializers = custom_deserializers self.custom_deserializers = custom_deserializers
def create(self, media_type): def create(self, mimetype):
deserialize_callable = self.get_deserializer_callable( deserialize_callable = self.get_deserializer_callable(
media_type.mimetype) mimetype)
return PrimitiveDeserializer( return PrimitiveDeserializer(
media_type.mimetype, deserialize_callable) mimetype, deserialize_callable)
def get_deserializer_callable(self, mimetype): def get_deserializer_callable(self, mimetype):
if mimetype in self.custom_deserializers: if mimetype in self.custom_deserializers:

View file

@ -3,6 +3,7 @@ from openapi_core.deserializing.parameters.exceptions import (
EmptyParameterValue, EmptyParameterValue,
) )
from openapi_core.schema.parameters.enums import ParameterLocation from openapi_core.schema.parameters.enums import ParameterLocation
from openapi_core.spec.parameters import get_aslist, get_explode, get_style
class PrimitiveDeserializer(object): class PrimitiveDeserializer(object):
@ -11,15 +12,20 @@ class PrimitiveDeserializer(object):
self.param = param self.param = param
self.deserializer_callable = deserializer_callable self.deserializer_callable = deserializer_callable
def __call__(self, value): self.aslist = get_aslist(self.param)
if (self.param.location == ParameterLocation.QUERY and value == "" and self.explode = get_explode(self.param)
not self.param.allow_empty_value): self.style = get_style(self.param)
raise EmptyParameterValue(
value, self.param.style, self.param.name)
if not self.param.aslist or self.param.explode: def __call__(self, value):
style = get_style(self.param)
if (self.param['in'] == 'query' and value == "" and
not self.param.getkey('allowEmptyValue', False)):
raise EmptyParameterValue(
value, self.style, self.param['name'])
if not self.aslist or self.explode:
return value return value
try: try:
return self.deserializer_callable(value) return self.deserializer_callable(value)
except (ValueError, TypeError, AttributeError): except (ValueError, TypeError, AttributeError):
raise DeserializeError(value, self.param.style) raise DeserializeError(value, self.style)

View file

@ -3,24 +3,26 @@ import warnings
from openapi_core.deserializing.parameters.deserializers import ( from openapi_core.deserializing.parameters.deserializers import (
PrimitiveDeserializer, PrimitiveDeserializer,
) )
from openapi_core.schema.parameters.enums import ParameterStyle from openapi_core.schema.parameters import get_style
class ParameterDeserializersFactory(object): class ParameterDeserializersFactory(object):
PARAMETER_STYLE_DESERIALIZERS = { PARAMETER_STYLE_DESERIALIZERS = {
ParameterStyle.FORM: lambda x: x.split(','), 'form': lambda x: x.split(','),
ParameterStyle.SIMPLE: lambda x: x.split(','), 'simple': lambda x: x.split(','),
ParameterStyle.SPACE_DELIMITED: lambda x: x.split(' '), 'spaceDelimited': lambda x: x.split(' '),
ParameterStyle.PIPE_DELIMITED: lambda x: x.split('|'), 'pipeDelimited': lambda x: x.split('|'),
} }
def create(self, param): def create(self, param):
if param.deprecated: if param.getkey('deprecated', False):
warnings.warn( warnings.warn(
"{0} parameter is deprecated".format(param.name), "{0} parameter is deprecated".format(param['name']),
DeprecationWarning, DeprecationWarning,
) )
deserialize_callable = self.PARAMETER_STYLE_DESERIALIZERS[param.style] style = get_style(param)
deserialize_callable = self.PARAMETER_STYLE_DESERIALIZERS[style]
return PrimitiveDeserializer(param, deserialize_callable) return PrimitiveDeserializer(param, deserialize_callable)

View file

@ -0,0 +1,98 @@
import sys
class Spec(object):
sep = '/'
def __new__(cls, *args):
return cls._from_parts(args)
@classmethod
def _parse_args(cls, args):
# This is useful when you don't want to create an instance, just
# canonicalize some constructor arguments.
parts = []
for a in args:
if isinstance(a, Spec):
parts += a._parts
else:
if isinstance(a, str):
# Force-cast str subclasses to str (issue #21127)
parts.append(str(a))
else:
raise TypeError(
"argument should be a str object or a Spec "
"object returning str, not %r"
% type(a))
return cls.parse_parts(parts)
@classmethod
def parse_parts(cls, parts):
parsed = []
sep = cls.sep
root = ''
it = reversed(parts)
for part in it:
if not part:
continue
root, rel = cls.splitroot(part)
if sep in rel:
for x in reversed(rel.split(sep)):
if x and x != '.':
parsed.append(sys.intern(x))
else:
if rel and rel != '.':
parsed.append(sys.intern(rel))
parsed.reverse()
return root, parsed
@classmethod
def splitroot(cls, part, sep=sep):
if part and part[0] == sep:
stripped_part = part.lstrip(sep)
# According to POSIX path resolution:
# http://pubs.opengroup.org/onlinepubs/009695399/basedefs/xbd_chap04.html#tag_04_11
# "A pathname that begins with two successive slashes may be
# interpreted in an implementation-defined manner, although more
# than two leading slashes shall be treated as a single slash".
if len(part) - len(stripped_part) == 2:
return sep * 2, stripped_part
else:
return sep, stripped_part
else:
return '', part
@classmethod
def _from_parts(cls, args):
self = object.__new__(cls)
root, parts = cls._parse_args(args)
self._root = root
self._parts = parts
return self
@classmethod
def _from_parsed_parts(cls, root, parts):
self = object.__new__(cls)
self._root = root
self._parts = parts
return self
def join_parsed_parts(self, root, parts, root2, parts2):
"""
Join the two paths represented by the respective
(root, parts) tuples. Return a new (root, parts) tuple.
"""
if root2:
return root2, root2 + parts2[1:]
elif parts:
return root, parts + parts2
return root2, parts2
def _make_child(self, args):
root, parts = self._parse_args(args)
root, parts = self.join_parsed_parts(
self._root, self._parts, root, parts)
return self._from_parsed_parts(root, parts)
def __truediv__(self, key):
return self._make_child((key,))

View file

@ -1,6 +1,8 @@
"""OpenAPI core servers models module""" """OpenAPI core servers models module"""
from six import iteritems from six import iteritems
from openapi_core.schema.servers.utils import is_absolute
class Server(object): class Server(object):
@ -30,7 +32,7 @@ class Server(object):
def is_absolute(self, url=None): def is_absolute(self, url=None):
if url is None: if url is None:
url = self.url url = self.url
return url.startswith('//') or '://' in url return is_absolute(url)
class ServerVariable(object): class ServerVariable(object):

View file

@ -0,0 +1,2 @@
def is_absolute(url):
return url.startswith('//') or '://' in url

View file

@ -3,6 +3,7 @@ from jsonschema.validators import RefResolver
from openapi_spec_validator import ( from openapi_spec_validator import (
default_handlers, openapi_v3_spec_validator, default_handlers, openapi_v3_spec_validator,
) )
from openapi_spec_validator.validators import Dereferencer
from openapi_core.schema.specs.factories import SpecFactory from openapi_core.schema.specs.factories import SpecFactory
@ -16,5 +17,8 @@ def create_spec(
spec_resolver = RefResolver( spec_resolver = RefResolver(
spec_url, spec_dict, handlers=handlers) spec_url, spec_dict, handlers=handlers)
dereferencer = Dereferencer(spec_resolver)
from openapi_core.spec.paths import SpecPath
return SpecPath.from_spec(spec_dict, dereferencer)
spec_factory = SpecFactory(spec_resolver) spec_factory = SpecFactory(spec_resolver)
return spec_factory.create(spec_dict, spec_url=spec_url) return spec_factory.create(spec_dict, spec_url=spec_url)

View file

@ -7,13 +7,14 @@ from openapi_core.security.providers import (
class SecurityProviderFactory(object): class SecurityProviderFactory(object):
PROVIDERS = { PROVIDERS = {
SecuritySchemeType.API_KEY: ApiKeyProvider, 'apiKey': ApiKeyProvider,
SecuritySchemeType.HTTP: HttpProvider, 'http': HttpProvider,
} }
def create(self, scheme): def create(self, scheme):
if scheme.type == SecuritySchemeType.API_KEY: scheme_type = scheme['type']
if scheme_type == 'apiKey':
return ApiKeyProvider(scheme) return ApiKeyProvider(scheme)
elif scheme.type == SecuritySchemeType.HTTP: elif scheme_type == 'http':
return HttpProvider(scheme) return HttpProvider(scheme)
return UnsupportedProvider(scheme) return UnsupportedProvider(scheme)

View file

@ -18,10 +18,12 @@ class UnsupportedProvider(BaseProvider):
class ApiKeyProvider(BaseProvider): class ApiKeyProvider(BaseProvider):
def __call__(self, request): def __call__(self, request):
source = getattr(request.parameters, self.scheme.apikey_in.value) name = self.scheme['name']
if self.scheme.name not in source: location = self.scheme['in']
source = getattr(request.parameters, location)
if name not in source:
raise SecurityError("Missing api key parameter.") raise SecurityError("Missing api key parameter.")
return source.get(self.scheme.name) return source[name]
class HttpProvider(BaseProvider): class HttpProvider(BaseProvider):
@ -35,7 +37,8 @@ class HttpProvider(BaseProvider):
except ValueError: except ValueError:
raise SecurityError('Could not parse authorization header.') raise SecurityError('Could not parse authorization header.')
if auth_type.lower() != self.scheme.scheme.value: scheme = self.scheme['scheme']
if auth_type.lower() != scheme:
raise SecurityError( raise SecurityError(
'Unknown authorization method %s' % auth_type) 'Unknown authorization method %s' % auth_type)

View file

View file

@ -0,0 +1,23 @@
from contextlib import contextmanager
from dictpath.accessors import DictOrListAccessor
class SpecAccessor(DictOrListAccessor):
def __init__(self, dict_or_list, dereferencer):
super(SpecAccessor, self).__init__(dict_or_list)
self.dereferencer = dereferencer
@contextmanager
def open(self, parts):
content = self.dict_or_list
for part in parts:
content = content[part]
if '$ref' in content:
content = self.dereferencer.dereference(
content)
try:
yield content
finally:
pass

View file

@ -0,0 +1,24 @@
def get_aslist(param):
return (
param.get('schema', None) and
param['schema']['type'] in ['array', 'object']
)
def get_style(param):
if 'style' in param:
return param['style']
# determine default
return (
'simple' if param['in'] in ['path', 'header'] else 'form'
)
def get_explode(param):
if 'explode' in param:
return param['explode']
#determine default
style = get_style(param)
return style == 'form'

View file

@ -0,0 +1,16 @@
from dictpath.paths import AccessorPath
from openapi_core.spec.accessors import SpecAccessor
SPEC_SEPARATOR = '#'
class SpecPath(AccessorPath):
@classmethod
def from_spec(
cls, spec_dict, dereferencer=None, *args,
separator=SPEC_SEPARATOR,
):
accessor = SpecAccessor(spec_dict, dereferencer)
return cls(accessor, *args, separator=separator)

View file

@ -0,0 +1,20 @@
from six import iteritems
def get_all_properties(schema):
properties = schema.get('properties', {})
properties_dict = dict(iteritems(properties))
if 'allOf'not in schema:
return properties_dict
for subschema in schema / 'allOf':
subschema_props = get_all_properties(subschema)
properties_dict.update(subschema_props)
return properties_dict
def get_all_properties_names(schema):
all_properties = get_all_properties(schema)
return set(all_properties.keys())

View file

@ -0,0 +1,18 @@
from six import iteritems
def get_server_default_variables(server):
if 'variables' not in server:
return {}
defaults = {}
variables = server / 'variables'
for name, variable in iteritems(variables):
defaults[name] = variable['default']
return defaults
def get_server_url(server, **variables):
if not variables:
variables = get_server_default_variables(server)
return server['url'].format(**variables)

View file

@ -0,0 +1,6 @@
from openapi_core.spec.servers import get_server_url
def get_spec_url(spec, index=0):
servers = spec / 'servers'
return get_server_url(servers / 0)

View file

@ -12,13 +12,11 @@ class MediaTypeFinder(object):
self.content = content self.content = content
def find(self, request): def find(self, request):
try: if request.mimetype in self.content:
return self.content[request.mimetype] return self.content / request.mimetype, request.mimetype
except KeyError:
pass
for key, value in iteritems(self.content): for key, value in self.content.items():
if fnmatch.fnmatch(request.mimetype, key): if fnmatch.fnmatch(request.mimetype, key):
return value return value, key
raise MediaTypeNotFound(request.mimetype, list(self.content.keys())) raise MediaTypeNotFound(request.mimetype, list(self.content.keys()))

View file

@ -3,6 +3,7 @@ from more_itertools import peekable
from six import iteritems from six import iteritems
from six.moves.urllib.parse import urljoin, urlparse from six.moves.urllib.parse import urljoin, urlparse
from openapi_core.schema.servers.utils import is_absolute
from openapi_core.templating.datatypes import TemplateResult from openapi_core.templating.datatypes import TemplateResult
from openapi_core.templating.util import parse, search from openapi_core.templating.util import parse, search
from openapi_core.templating.paths.exceptions import ( from openapi_core.templating.paths.exceptions import (
@ -40,7 +41,8 @@ class PathFinder(object):
def _get_paths_iter(self, full_url_pattern): def _get_paths_iter(self, full_url_pattern):
template_paths = [] template_paths = []
for path_pattern, path in iteritems(self.spec.paths): paths = self.spec / 'paths'
for path_pattern, path in paths.items():
# simple path. # simple path.
# Return right away since it is always the most concrete # Return right away since it is always the most concrete
if full_url_pattern.endswith(path_pattern): if full_url_pattern.endswith(path_pattern):
@ -59,22 +61,24 @@ class PathFinder(object):
def _get_operations_iter(self, request_method, paths_iter): def _get_operations_iter(self, request_method, paths_iter):
for path, path_result in paths_iter: for path, path_result in paths_iter:
if request_method not in path.operations: if request_method not in path:
continue continue
operation = path.operations[request_method] operation = path / request_method
yield (path, operation, path_result) yield (path, operation, path_result)
def _get_servers_iter(self, full_url_pattern, ooperations_iter): def _get_servers_iter(self, full_url_pattern, ooperations_iter):
for path, operation, path_result in ooperations_iter: for path, operation, path_result in ooperations_iter:
servers = path.servers or operation.servers or self.spec.servers servers = path.get('servers', None) or \
operation.get('servers', None) or \
self.spec.get('servers', [{'url': '/'}])
for server in servers: for server in servers:
server_url_pattern = full_url_pattern.rsplit( server_url_pattern = full_url_pattern.rsplit(
path_result.resolved, 1)[0] path_result.resolved, 1)[0]
server_url = server.url server_url = server['url']
if not server.is_absolute(): if not is_absolute(server_url):
# relative to absolute url # relative to absolute url
if self.base_url is not None: if self.base_url is not None:
server_url = urljoin(self.base_url, server.url) server_url = urljoin(self.base_url, server['url'])
# if no base url check only path part # if no base url check only path part
else: else:
server_url_pattern = urlparse(server_url_pattern).path server_url_pattern = urlparse(server_url_pattern).path
@ -82,17 +86,17 @@ class PathFinder(object):
server_url = server_url[:-1] server_url = server_url[:-1]
# simple path # simple path
if server_url_pattern == server_url: if server_url_pattern == server_url:
server_result = TemplateResult(server.url, {}) server_result = TemplateResult(server['url'], {})
yield ( yield (
path, operation, server, path, operation, server,
path_result, server_result, path_result, server_result,
) )
# template path # template path
else: else:
result = parse(server.url, server_url_pattern) result = parse(server['url'], server_url_pattern)
if result: if result:
server_result = TemplateResult( server_result = TemplateResult(
server.url, result.named) server['url'], result.named)
yield ( yield (
path, operation, server, path, operation, server,
path_result, server_result, path_result, server_result,

View file

@ -7,17 +7,15 @@ class ResponseFinder(object):
self.responses = responses self.responses = responses
def find(self, http_status='default'): def find(self, http_status='default'):
try: if http_status in self.responses:
return self.responses[http_status] return self.responses / http_status
except KeyError:
pass
# try range # try range
http_status_range = '{0}XX'.format(http_status[0]) http_status_range = '{0}XX'.format(http_status[0])
if http_status_range in self.responses: if http_status_range in self.responses:
return self.responses[http_status_range] return self.responses / http_status_range
if 'default' not in self.responses: if 'default' not in self.responses:
raise ResponseNotFound(http_status, self.responses) raise ResponseNotFound(http_status, self.responses)
return self.responses['default'] return self.responses / 'default'

View file

@ -18,15 +18,15 @@ from openapi_core.unmarshalling.schemas.unmarshallers import (
class SchemaUnmarshallersFactory(object): class SchemaUnmarshallersFactory(object):
PRIMITIVE_UNMARSHALLERS = { PRIMITIVE_UNMARSHALLERS = {
SchemaType.STRING: StringUnmarshaller, 'string': StringUnmarshaller,
SchemaType.INTEGER: IntegerUnmarshaller, 'integer': IntegerUnmarshaller,
SchemaType.NUMBER: NumberUnmarshaller, 'number': NumberUnmarshaller,
SchemaType.BOOLEAN: BooleanUnmarshaller, 'boolean': BooleanUnmarshaller,
} }
COMPLEX_UNMARSHALLERS = { COMPLEX_UNMARSHALLERS = {
SchemaType.ARRAY: ArrayUnmarshaller, 'array': ArrayUnmarshaller,
SchemaType.OBJECT: ObjectUnmarshaller, 'object': ObjectUnmarshaller,
SchemaType.ANY: AnyUnmarshaller, 'any': AnyUnmarshaller,
} }
CONTEXT_VALIDATION = { CONTEXT_VALIDATION = {
@ -46,12 +46,13 @@ class SchemaUnmarshallersFactory(object):
def create(self, schema, type_override=None): def create(self, schema, type_override=None):
"""Create unmarshaller from the schema.""" """Create unmarshaller from the schema."""
if not isinstance(schema, Schema): if schema is None:
raise TypeError("schema not type of Schema") raise TypeError("Invalid schema")
if schema.deprecated:
if schema.getkey('deprecated', False):
warnings.warn("The schema is deprecated", DeprecationWarning) warnings.warn("The schema is deprecated", DeprecationWarning)
schema_type = type_override or schema.type schema_type = type_override or schema.getkey('type', 'any')
if schema_type in self.PRIMITIVE_UNMARSHALLERS: if schema_type in self.PRIMITIVE_UNMARSHALLERS:
klass = self.PRIMITIVE_UNMARSHALLERS[schema_type] klass = self.PRIMITIVE_UNMARSHALLERS[schema_type]
kwargs = dict(schema=schema) kwargs = dict(schema=schema)
@ -63,10 +64,11 @@ class SchemaUnmarshallersFactory(object):
context=self.context, context=self.context,
) )
formatter = self.get_formatter(klass.FORMATTERS, schema.format) schema_format = schema.getkey('format')
formatter = self.get_formatter(klass.FORMATTERS, schema_format)
if formatter is None: if formatter is None:
raise FormatterNotFoundError(schema.format) raise FormatterNotFoundError(schema_format)
validator = self.get_validator(schema) validator = self.get_validator(schema)
@ -87,4 +89,5 @@ class SchemaUnmarshallersFactory(object):
} }
if self.context is not None: if self.context is not None:
kwargs[self.CONTEXT_VALIDATION[self.context]] = True kwargs[self.CONTEXT_VALIDATION[self.context]] = True
return OAS30Validator(schema.__dict__, **kwargs) with schema.open() as schema_dict:
return OAS30Validator(schema_dict, **kwargs)

View file

@ -15,6 +15,9 @@ from openapi_core.extensions.models.factories import ModelFactory
from openapi_core.schema.schemas.enums import SchemaFormat, SchemaType from openapi_core.schema.schemas.enums import SchemaFormat, SchemaType
from openapi_core.schema.schemas.models import Schema from openapi_core.schema.schemas.models import Schema
from openapi_core.schema.schemas.types import NoValue from openapi_core.schema.schemas.types import NoValue
from openapi_core.spec.schemas import (
get_all_properties, get_all_properties_names
)
from openapi_core.unmarshalling.schemas.enums import UnmarshalContext from openapi_core.unmarshalling.schemas.enums import UnmarshalContext
from openapi_core.unmarshalling.schemas.exceptions import ( from openapi_core.unmarshalling.schemas.exceptions import (
UnmarshalError, ValidateError, InvalidSchemaValue, UnmarshalError, ValidateError, InvalidSchemaValue,
@ -40,7 +43,7 @@ class PrimitiveTypeUnmarshaller(object):
def __call__(self, value=NoValue): def __call__(self, value=NoValue):
if value is NoValue: if value is NoValue:
value = self.schema.default value = self.schema.getkey('default')
if value is None: if value is None:
return return
@ -51,21 +54,24 @@ class PrimitiveTypeUnmarshaller(object):
def _formatter_validate(self, value): def _formatter_validate(self, value):
result = self.formatter.validate(value) result = self.formatter.validate(value)
if not result: if not result:
raise InvalidSchemaValue(value, self.schema.type) schema_type = self.schema.getkey('type', 'any')
raise InvalidSchemaValue(value, schema_type)
def validate(self, value): def validate(self, value):
errors_iter = self.validator.iter_errors(value) errors_iter = self.validator.iter_errors(value)
errors = tuple(errors_iter) errors = tuple(errors_iter)
if errors: if errors:
schema_type = self.schema.getkey('type', 'any')
raise InvalidSchemaValue( raise InvalidSchemaValue(
value, self.schema.type, schema_errors=errors) value, schema_type, schema_errors=errors)
def unmarshal(self, value): def unmarshal(self, value):
try: try:
return self.formatter.unmarshal(value) return self.formatter.unmarshal(value)
except ValueError as exc: except ValueError as exc:
schema_format = self.schema.getkey('format')
raise InvalidSchemaFormatValue( raise InvalidSchemaFormatValue(
value, self.schema.format, exc) value, schema_format, exc)
class StringUnmarshaller(PrimitiveTypeUnmarshaller): class StringUnmarshaller(PrimitiveTypeUnmarshaller):
@ -140,11 +146,11 @@ class ArrayUnmarshaller(ComplexUnmarshaller):
@property @property
def items_unmarshaller(self): def items_unmarshaller(self):
return self.unmarshallers_factory.create(self.schema.items) return self.unmarshallers_factory.create(self.schema / 'items')
def __call__(self, value=NoValue): def __call__(self, value=NoValue):
value = super(ArrayUnmarshaller, self).__call__(value) value = super(ArrayUnmarshaller, self).__call__(value)
if value is None and self.schema.nullable: if value is None and self.schema.getkey('nullable', False):
return None return None
return list(map(self.items_unmarshaller, value)) return list(map(self.items_unmarshaller, value))
@ -170,9 +176,9 @@ class ObjectUnmarshaller(ComplexUnmarshaller):
return self._unmarshal_object(value) return self._unmarshal_object(value)
def _unmarshal_object(self, value=NoValue): def _unmarshal_object(self, value=NoValue):
if self.schema.one_of: if 'oneOf' in self.schema:
properties = None properties = None
for one_of_schema in self.schema.one_of: for one_of_schema in self.schema / 'oneOf':
try: try:
unmarshalled = self._unmarshal_properties( unmarshalled = self._unmarshal_properties(
value, one_of_schema) value, one_of_schema)
@ -190,46 +196,49 @@ class ObjectUnmarshaller(ComplexUnmarshaller):
else: else:
properties = self._unmarshal_properties(value) properties = self._unmarshal_properties(value)
if 'x-model' in self.schema.extensions: if 'x-model' in self.schema:
extension = self.schema.extensions['x-model'] name = self.schema['x-model']
return self.model_factory.create(properties, name=extension.value) return self.model_factory.create(properties, name=name)
return properties return properties
def _unmarshal_properties(self, value=NoValue, one_of_schema=None): def _unmarshal_properties(self, value=NoValue, one_of_schema=None):
all_props = self.schema.get_all_properties() all_props = get_all_properties(self.schema)
all_props_names = self.schema.get_all_properties_names() all_props_names = get_all_properties_names(self.schema)
if one_of_schema is not None: if one_of_schema is not None:
all_props.update(one_of_schema.get_all_properties()) all_props.update(get_all_properties(one_of_schema))
all_props_names |= one_of_schema.\ all_props_names |= get_all_properties_names(one_of_schema)
get_all_properties_names()
value_props_names = value.keys() value_props_names = value.keys()
extra_props = set(value_props_names) - set(all_props_names) extra_props = set(value_props_names) - set(all_props_names)
properties = {} properties = {}
if isinstance(self.schema.additional_properties, Schema): additional_properties = self.schema.getkey('additionalProperties', True)
if isinstance(additional_properties, dict):
additional_prop_schema = self.schema / 'additionalProperties'
for prop_name in extra_props: for prop_name in extra_props:
prop_value = value[prop_name] prop_value = value[prop_name]
properties[prop_name] = self.unmarshallers_factory.create( properties[prop_name] = self.unmarshallers_factory.create(
self.schema.additional_properties)(prop_value) additional_prop_schema)(prop_value)
elif self.schema.additional_properties is True: elif additional_properties is True:
for prop_name in extra_props: for prop_name in extra_props:
prop_value = value[prop_name] prop_value = value[prop_name]
properties[prop_name] = prop_value properties[prop_name] = prop_value
for prop_name, prop in iteritems(all_props): for prop_name, prop in iteritems(all_props):
if self.context == UnmarshalContext.REQUEST and prop.read_only: read_only = prop.getkey('readOnly', False)
if self.context == UnmarshalContext.REQUEST and read_only:
continue continue
if self.context == UnmarshalContext.RESPONSE and prop.write_only: write_only = prop.getkey('writeOnly', False)
if self.context == UnmarshalContext.RESPONSE and write_only:
continue continue
try: try:
prop_value = value[prop_name] prop_value = value[prop_name]
except KeyError: except KeyError:
if prop.default is NoValue: if 'default' not in prop:
continue continue
prop_value = prop.default prop_value = prop['default']
properties[prop_name] = self.unmarshallers_factory.create( properties[prop_name] = self.unmarshallers_factory.create(
prop)(prop_value) prop)(prop_value)
@ -244,8 +253,8 @@ class AnyUnmarshaller(ComplexUnmarshaller):
} }
SCHEMA_TYPES_ORDER = [ SCHEMA_TYPES_ORDER = [
SchemaType.OBJECT, SchemaType.ARRAY, SchemaType.BOOLEAN, 'object', 'array', 'boolean',
SchemaType.INTEGER, SchemaType.NUMBER, SchemaType.STRING, 'integer', 'number', 'string',
] ]
def unmarshal(self, value=NoValue): def unmarshal(self, value=NoValue):
@ -272,9 +281,11 @@ class AnyUnmarshaller(ComplexUnmarshaller):
return value return value
def _get_one_of_schema(self, value): def _get_one_of_schema(self, value):
if not self.schema.one_of: if 'oneOf' not in self.schema:
return return
for subschema in self.schema.one_of:
one_of_schemas = self.schema / 'oneOf'
for subschema in one_of_schemas:
unmarshaller = self.unmarshallers_factory.create(subschema) unmarshaller = self.unmarshallers_factory.create(subschema)
try: try:
unmarshaller.validate(value) unmarshaller.validate(value)
@ -284,10 +295,12 @@ class AnyUnmarshaller(ComplexUnmarshaller):
return subschema return subschema
def _get_all_of_schema(self, value): def _get_all_of_schema(self, value):
if not self.schema.all_of: if 'allOf' not in self.schema:
return return
for subschema in self.schema.all_of:
if subschema.type == SchemaType.ANY: all_of_schemas = self.schema / 'allOf'
for subschema in all_of_schemas:
if 'type' not in subschema:
continue continue
unmarshaller = self.unmarshallers_factory.create(subschema) unmarshaller = self.unmarshallers_factory.create(subschema)
try: try:

View file

@ -9,6 +9,7 @@ from openapi_core.schema.parameters.exceptions import (
) )
from openapi_core.schema.request_bodies.exceptions import MissingRequestBody from openapi_core.schema.request_bodies.exceptions import MissingRequestBody
from openapi_core.security.exceptions import SecurityError from openapi_core.security.exceptions import SecurityError
from openapi_core.spec.parameters import get_aslist, get_explode
from openapi_core.templating.media_types.exceptions import MediaTypeFinderError from openapi_core.templating.media_types.exceptions import MediaTypeFinderError
from openapi_core.templating.paths.exceptions import PathError from openapi_core.templating.paths.exceptions import PathError
from openapi_core.unmarshalling.schemas.enums import UnmarshalContext from openapi_core.unmarshalling.schemas.enums import UnmarshalContext
@ -38,10 +39,17 @@ class RequestValidator(BaseValidator):
request.parameters.path = request.parameters.path or \ request.parameters.path = request.parameters.path or \
path_result.variables path_result.variables
operation_params = operation.get('parameters', [])
operation_params_iter = operation_params and \
iter(operation_params) or []
path_params = path.get('parameters', [])
params_params_iter = path_params and \
iter(path_params) or []
params, params_errors = self._get_parameters( params, params_errors = self._get_parameters(
request, chain( request, chain(
iteritems(operation.parameters), operation_params_iter,
iteritems(path.parameters) params_params_iter,
) )
) )
@ -63,10 +71,17 @@ class RequestValidator(BaseValidator):
request.parameters.path = request.parameters.path or \ request.parameters.path = request.parameters.path or \
path_result.variables path_result.variables
operation_params = operation.get('parameters', [])
operation_params_iter = operation_params and \
iter(operation_params) or []
path_params = path.get('parameters', [])
params_params_iter = path_params and \
iter(path_params) or []
params, params_errors = self._get_parameters( params, params_errors = self._get_parameters(
request, chain( request, chain(
iteritems(operation.parameters), operation_params_iter,
iteritems(path.parameters) params_params_iter,
) )
) )
return RequestValidationResult( return RequestValidationResult(
@ -87,9 +102,11 @@ class RequestValidator(BaseValidator):
) )
def _get_security(self, request, operation): def _get_security(self, request, operation):
security = self.spec.security security = None
if operation.security is not None: if 'security' in self.spec:
security = operation.security security = self.spec / 'security'
if 'security' in operation:
security = operation / 'security'
if not security: if not security:
return {} return {}
@ -99,7 +116,7 @@ class RequestValidator(BaseValidator):
return { return {
scheme_name: self._get_security_value( scheme_name: self._get_security_value(
scheme_name, request) scheme_name, request)
for scheme_name in security_requirement for scheme_name in security_requirement.keys()
} }
except SecurityError: except SecurityError:
continue continue
@ -110,21 +127,26 @@ class RequestValidator(BaseValidator):
errors = [] errors = []
seen = set() seen = set()
locations = {} locations = {}
for param_name, param in params: for param in params:
if (param_name, param.location.value) in seen: param_name = param['name']
param_location = param['in']
if (param_name, param_location) in seen:
# skip parameter already seen # skip parameter already seen
# e.g. overriden path item paremeter on operation # e.g. overriden path item paremeter on operation
continue continue
seen.add((param_name, param.location.value)) seen.add((param_name, param_location))
try: try:
raw_value = self._get_parameter_value(param, request) raw_value = self._get_parameter_value(param, request)
except MissingRequiredParameter as exc: except MissingRequiredParameter as exc:
errors.append(exc) errors.append(exc)
continue continue
except MissingParameter: except MissingParameter:
if not param.schema or not param.schema.has_default(): if 'schema' not in param:
continue continue
casted = param.schema.default schema = param / 'schema'
if 'default' not in schema:
continue
casted = schema['default']
else: else:
try: try:
deserialised = self._deserialise_parameter( deserialised = self._deserialise_parameter(
@ -144,28 +166,29 @@ class RequestValidator(BaseValidator):
except (ValidateError, UnmarshalError) as exc: except (ValidateError, UnmarshalError) as exc:
errors.append(exc) errors.append(exc)
else: else:
locations.setdefault(param.location.value, {}) locations.setdefault(param_location, {})
locations[param.location.value][param_name] = unmarshalled locations[param_location][param_name] = unmarshalled
return RequestParameters(**locations), errors return RequestParameters(**locations), errors
def _get_body(self, request, operation): def _get_body(self, request, operation):
if operation.request_body is None: if not 'requestBody' in operation:
return None, [] return None, []
request_body = operation / 'requestBody'
try: try:
media_type = self._get_media_type( media_type, mimetype = self._get_media_type(
operation.request_body.content, request) request_body / 'content', request)
except MediaTypeFinderError as exc: except MediaTypeFinderError as exc:
return None, [exc, ] return None, [exc, ]
try: try:
raw_body = self._get_body_value(operation.request_body, request) raw_body = self._get_body_value(request_body, request)
except MissingRequestBody as exc: except MissingRequestBody as exc:
return None, [exc, ] return None, [exc, ]
try: try:
deserialised = self._deserialise_media_type(media_type, raw_body) deserialised = self._deserialise_data(mimetype, raw_body)
except DeserializeError as exc: except DeserializeError as exc:
return None, [exc, ] return None, [exc, ]
@ -182,33 +205,37 @@ class RequestValidator(BaseValidator):
return body, [] return body, []
def _get_security_value(self, scheme_name, request): def _get_security_value(self, scheme_name, request):
scheme = self.spec.components.security_schemes.get(scheme_name) security_schemes = self.spec / 'components#securitySchemes'
if not scheme: if scheme_name not in security_schemes:
return return
scheme = security_schemes[scheme_name]
from openapi_core.security.factories import SecurityProviderFactory from openapi_core.security.factories import SecurityProviderFactory
security_provider_factory = SecurityProviderFactory() security_provider_factory = SecurityProviderFactory()
security_provider = security_provider_factory.create(scheme) security_provider = security_provider_factory.create(scheme)
return security_provider(request) return security_provider(request)
def _get_parameter_value(self, param, request): def _get_parameter_value(self, param, request):
location = request.parameters[param.location.value] param_location = param['in']
location = request.parameters[param_location]
if param.name not in location: if param['name'] not in location:
if param.required: if param.getkey('required', False):
raise MissingRequiredParameter(param.name) raise MissingRequiredParameter(param['name'])
raise MissingParameter(param.name) raise MissingParameter(param['name'])
if param.aslist and param.explode: aslist = get_aslist(param)
explode = get_explode(param)
if aslist and explode:
if hasattr(location, 'getall'): if hasattr(location, 'getall'):
return location.getall(param.name) return location.getall(param['name'])
return location.getlist(param.name) return location.getlist(param['name'])
return location[param.name] return location[param['name']]
def _get_body_value(self, request_body, request): def _get_body_value(self, request_body, request):
if not request.body and request_body.required: required = request_body.getkey('required', False)
if not request.body and required:
raise MissingRequestBody(request) raise MissingRequestBody(request)
return request.body return request.body

View file

@ -43,7 +43,7 @@ class ResponseValidator(BaseValidator):
def _get_operation_response(self, operation, response): def _get_operation_response(self, operation, response):
from openapi_core.templating.responses.finders import ResponseFinder from openapi_core.templating.responses.finders import ResponseFinder
finder = ResponseFinder(operation.responses) finder = ResponseFinder(operation / 'responses')
return finder.find(str(response.status_code)) return finder.find(str(response.status_code))
def _validate_data(self, request, response): def _validate_data(self, request, response):
@ -67,12 +67,12 @@ class ResponseValidator(BaseValidator):
) )
def _get_data(self, response, operation_response): def _get_data(self, response, operation_response):
if not operation_response.content: if 'content' not in operation_response:
return None, [] return None, []
try: try:
media_type = self._get_media_type( media_type, mimetype = self._get_media_type(
operation_response.content, response) operation_response / 'content', response)
except MediaTypeFinderError as exc: except MediaTypeFinderError as exc:
return None, [exc, ] return None, [exc, ]
@ -82,7 +82,7 @@ class ResponseValidator(BaseValidator):
return None, [exc, ] return None, [exc, ]
try: try:
deserialised = self._deserialise_media_type(media_type, raw_data) deserialised = self._deserialise_data(mimetype, raw_data)
except DeserializeError as exc: except DeserializeError as exc:
return None, [exc, ] return None, [exc, ]

View file

@ -26,36 +26,38 @@ class BaseValidator(object):
finder = MediaTypeFinder(content) finder = MediaTypeFinder(content)
return finder.find(request_or_response) return finder.find(request_or_response)
def _deserialise_media_type(self, media_type, value): def _deserialise_data(self, mimetype, value):
from openapi_core.deserializing.media_types.factories import ( from openapi_core.deserializing.media_types.factories import (
MediaTypeDeserializersFactory, MediaTypeDeserializersFactory,
) )
deserializers_factory = MediaTypeDeserializersFactory( deserializers_factory = MediaTypeDeserializersFactory(
self.custom_media_type_deserializers) self.custom_media_type_deserializers)
deserializer = deserializers_factory.create(media_type) deserializer = deserializers_factory.create(mimetype)
return deserializer(value) return deserializer(value)
def _cast(self, param_or_media_type, value): def _cast(self, param_or_media_type, value):
# return param_or_media_type.cast(value) # return param_or_media_type.cast(value)
if not param_or_media_type.schema: if not 'schema' in param_or_media_type:
return value return value
from openapi_core.casting.schemas.factories import SchemaCastersFactory from openapi_core.casting.schemas.factories import SchemaCastersFactory
casters_factory = SchemaCastersFactory() casters_factory = SchemaCastersFactory()
caster = casters_factory.create(param_or_media_type.schema) schema = param_or_media_type / 'schema'
caster = casters_factory.create(schema)
return caster(value) return caster(value)
def _unmarshal(self, param_or_media_type, value, context): def _unmarshal(self, param_or_media_type, value, context):
if not param_or_media_type.schema: if not 'schema' in param_or_media_type:
return value return value
from openapi_core.unmarshalling.schemas.factories import ( from openapi_core.unmarshalling.schemas.factories import (
SchemaUnmarshallersFactory, SchemaUnmarshallersFactory,
) )
spec_resolver = self.spec.accessor.dereferencer.resolver_manager.resolver
unmarshallers_factory = SchemaUnmarshallersFactory( unmarshallers_factory = SchemaUnmarshallersFactory(
self.spec._resolver, self.format_checker, spec_resolver, self.format_checker,
self.custom_formatters, context=context, self.custom_formatters, context=context,
) )
unmarshaller = unmarshallers_factory.create( schema = param_or_media_type / 'schema'
param_or_media_type.schema) unmarshaller = unmarshallers_factory.create(schema)
return unmarshaller(value) return unmarshaller(value)

View file

@ -38,6 +38,7 @@ paths:
default: 1 default: 1
- name: limit - name: limit
in: query in: query
style: form
description: How many items to return at one time (max 100) description: How many items to return at one time (max 100)
required: true required: true
schema: schema:

View file

@ -6,31 +6,32 @@ class TestLinkSpec(object):
def test_no_param(self, factory): def test_no_param(self, factory):
spec_dict = factory.spec_from_file("data/v3.0/links.yaml") spec_dict = factory.spec_from_file("data/v3.0/links.yaml")
spec = create_spec(spec_dict) spec = create_spec(spec_dict)
resp = spec['/status']['get'].responses['default'] resp = spec / 'paths#/status#get#responses#default'
assert len(resp.links) == 1 links = resp / 'links'
assert len(links) == 1
link = resp.links['noParamLink'] link = links / 'noParamLink'
assert link['operationId'] == 'noParOp'
assert link.operationId == 'noParOp' assert 'server' not in link
assert link.server is None assert 'requestBody' not in link
assert link.request_body is None assert 'parameters' not in link
assert len(link.parameters) == 0
def test_param(self, factory): def test_param(self, factory):
spec_dict = factory.spec_from_file("data/v3.0/links.yaml") spec_dict = factory.spec_from_file("data/v3.0/links.yaml")
spec = create_spec(spec_dict) spec = create_spec(spec_dict)
resp = spec['/status/{resourceId}']['get'].responses['default'] resp = spec / 'paths#/status/{resourceId}#get#responses#default'
assert len(resp.links) == 1 links = resp / 'links'
assert len(links) == 1
link = resp.links['paramLink'] link = links / 'paramLink'
assert link['operationId'] == 'paramOp'
assert 'server' not in link
assert link['requestBody'] == 'test'
assert link.operationId == 'paramOp' parameters = link['parameters']
assert link.server is None assert len(parameters) == 1
assert link.request_body == 'test'
assert len(link.parameters) == 1
param = link.parameters['opParam']
param = parameters['opParam']
assert param == '$request.path.resourceId' assert param == '$request.path.resourceId'

View file

@ -1,6 +1,5 @@
import pytest import pytest
from openapi_core.schema.parameters.enums import ParameterLocation
from openapi_core.shortcuts import create_spec from openapi_core.shortcuts import create_spec
@ -15,9 +14,12 @@ class TestMinimal(object):
spec_dict = factory.spec_from_file(spec_path) spec_dict = factory.spec_from_file(spec_path)
spec = create_spec(spec_dict) spec = create_spec(spec_dict)
path = spec['/resource/{resId}'] path = spec / 'paths#/resource/{resId}'
assert len(path.parameters) == 1 parameters = path / 'parameters'
param = path.parameters['resId'] assert len(parameters) == 1
assert param.required
assert param.location == ParameterLocation.PATH param = parameters[0]
assert param['name'] == 'resId'
assert param['required']
assert param['in'] == 'path'

View file

@ -2,18 +2,9 @@ import pytest
from base64 import b64encode from base64 import b64encode
from six import iteritems, text_type from six import iteritems, text_type
from openapi_core.schema.media_types.models import MediaType
from openapi_core.schema.operations.models import Operation
from openapi_core.schema.parameters.models import Parameter
from openapi_core.schema.paths.models import Path
from openapi_core.schema.request_bodies.models import RequestBody
from openapi_core.schema.responses.models import Response
from openapi_core.schema.schemas.models import Schema
from openapi_core.schema.security_requirements.models import (
SecurityRequirement,
)
from openapi_core.schema.servers.models import Server, ServerVariable
from openapi_core.shortcuts import create_spec from openapi_core.shortcuts import create_spec
from openapi_core.spec.servers import get_server_url
from openapi_core.spec.specs import get_spec_url
from openapi_core.validation.request.validators import RequestValidator from openapi_core.validation.request.validators import RequestValidator
from openapi_core.validation.response.validators import ResponseValidator from openapi_core.validation.response.validators import ResponseValidator
@ -51,123 +42,117 @@ class TestPetstore(object):
def test_spec(self, spec, spec_dict): def test_spec(self, spec, spec_dict):
url = 'http://petstore.swagger.io/v1' url = 'http://petstore.swagger.io/v1'
info = spec / 'info'
info_spec = spec_dict['info'] info_spec = spec_dict['info']
assert spec.info.title == info_spec['title'] assert info['title'] == info_spec['title']
assert spec.info.description == info_spec['description'] assert info['description'] == info_spec['description']
assert spec.info.terms_of_service == info_spec['termsOfService'] assert info['termsOfService'] == info_spec['termsOfService']
assert spec.info.version == info_spec['version'] assert info['version'] == info_spec['version']
contact = info / 'contact'
contact_spec = info_spec['contact'] contact_spec = info_spec['contact']
assert spec.info.contact.name == contact_spec['name'] assert contact['name'] == contact_spec['name']
assert spec.info.contact.url == contact_spec['url'] assert contact['url'] == contact_spec['url']
assert spec.info.contact.email == contact_spec['email'] assert contact['email'] == contact_spec['email']
license = info / 'license'
license_spec = info_spec['license'] license_spec = info_spec['license']
assert spec.info.license.name == license_spec['name'] assert license['name'] == license_spec['name']
assert spec.info.license.url == license_spec['url'] assert license['url'] == license_spec['url']
security = spec / 'security'
security_spec = spec_dict.get('security', []) security_spec = spec_dict.get('security', [])
for idx, security_req in enumerate(spec.security): for idx, security_reqs in enumerate(security):
assert type(security_req) == SecurityRequirement security_reqs_spec = security_spec[idx]
for scheme_name, security_req in iteritems(security_reqs):
security_req == security_reqs_spec[scheme_name]
security_req_spec = security_spec[idx] assert get_spec_url(spec) == url
for scheme_name in security_req:
security_req[scheme_name] == security_req_spec[scheme_name]
assert spec.get_server_url() == url
for idx, server in enumerate(spec.servers):
assert type(server) == Server
servers = spec / 'servers'
for idx, server in enumerate(servers):
server_spec = spec_dict['servers'][idx] server_spec = spec_dict['servers'][idx]
assert server.url == server_spec['url'] assert server['url'] == server_spec['url']
assert server.default_url == url assert get_server_url(server) == url
for variable_name, variable in iteritems(server.variables):
assert type(variable) == ServerVariable
assert variable.name == variable_name
variables = server / 'variables'
for variable_name, variable in iteritems(variables):
variable_spec = server_spec['variables'][variable_name] variable_spec = server_spec['variables'][variable_name]
assert variable.default == variable_spec['default'] assert variable['default'] == variable_spec['default']
assert variable.enum == variable_spec.get('enum') assert variable['enum'] == variable_spec.get('enum')
for path_name, path in iteritems(spec.paths):
assert type(path) == Path
paths = spec / 'paths'
for path_name, path in iteritems(paths):
path_spec = spec_dict['paths'][path_name] path_spec = spec_dict['paths'][path_name]
assert path.name == path_name assert path.getkey('summary') == path_spec.get('summary')
assert path.summary == path_spec.get('summary') assert path.getkey('description') == path_spec.get('description')
assert path.description == path_spec.get('description')
servers = path.get('servers', [])
servers_spec = path_spec.get('servers', []) servers_spec = path_spec.get('servers', [])
for idx, server in enumerate(path.servers): for idx, server in enumerate(servers):
assert type(server) == Server
server_spec = servers_spec[idx] server_spec = servers_spec[idx]
assert server.url == server_spec['url'] assert server.url == server_spec['url']
assert server.default_url == server_spec['url'] assert server.default_url == server_spec['url']
assert server.description == server_spec.get('description') assert server.description == server_spec.get('description')
for variable_name, variable in iteritems(server.variables): variables = server.get('variables', {})
assert type(variable) == ServerVariable for variable_name, variable in iteritems(variables):
assert variable.name == variable_name
variable_spec = server_spec['variables'][variable_name] variable_spec = server_spec['variables'][variable_name]
assert variable.default == variable_spec['default'] assert variable['default'] == variable_spec['default']
assert variable.enum == variable_spec.get('enum') assert variable.getkey('enum') == variable_spec.get('enum')
for http_method, operation in iteritems(path.operations): operations = [
'get', 'put', 'post', 'delete', 'options',
'head', 'patch', 'trace',
]
for http_method in operations:
if http_method not in path:
continue
operation = path / http_method
operation_spec = path_spec[http_method] operation_spec = path_spec[http_method]
assert type(operation) == Operation assert operation['operationId'] is not None
assert operation.path_name == path_name assert operation['tags'] == operation_spec['tags']
assert operation.http_method == http_method assert operation['summary'] == operation_spec.get('summary')
assert operation.operation_id is not None assert operation.getkey('description') == operation_spec.get(
assert operation.tags == operation_spec['tags']
assert operation.summary == operation_spec.get('summary')
assert operation.description == operation_spec.get(
'description') 'description')
ext_docs = operation.get('externalDocs')
ext_docs_spec = operation_spec.get('externalDocs') ext_docs_spec = operation_spec.get('externalDocs')
assert bool(ext_docs_spec) == bool(ext_docs)
if ext_docs_spec: if ext_docs_spec:
ext_docs = operation.external_docs assert ext_docs['url'] == ext_docs_spec['url']
assert ext_docs.url == ext_docs_spec['url'] assert ext_docs.getkey('description') == ext_docs_spec.get(
assert ext_docs.description == ext_docs_spec.get(
'description') 'description')
servers = operation.get('servers', [])
servers_spec = operation_spec.get('servers', []) servers_spec = operation_spec.get('servers', [])
for idx, server in enumerate(operation.servers): for idx, server in enumerate(servers):
assert type(server) == Server
server_spec = servers_spec[idx] server_spec = servers_spec[idx]
assert server.url == server_spec['url'] assert server['url'] == server_spec['url']
assert server.default_url == server_spec['url'] assert get_server_url(server) == server_spec['url']
assert server.description == server_spec.get('description') assert server['description'] == server_spec.get(
'description')
for variable_name, variable in iteritems(server.variables):
assert type(variable) == ServerVariable
assert variable.name == variable_name
variables = server.get('variables', {})
for variable_name, variable in iteritems(variables):
variable_spec = server_spec['variables'][variable_name] variable_spec = server_spec['variables'][variable_name]
assert variable.default == variable_spec['default'] assert variable['default'] == variable_spec['default']
assert variable.enum == variable_spec.get('enum') assert variable.getkey('enum') == variable_spec.get(
'enum')
security = operation.get('security', [])
security_spec = operation_spec.get('security') security_spec = operation_spec.get('security')
if security_spec is not None: if security_spec is not None:
for idx, security_req in enumerate(operation.security): for idx, security_reqs in enumerate(security):
assert type(security_req) == SecurityRequirement security_reqs_spec = security_spec[idx]
for scheme_name, security_req in iteritems(
security_req_spec = security_spec[idx] security_reqs):
for scheme_name in security_req: security_req == security_reqs_spec[scheme_name]
security_req[scheme_name] == security_req_spec[
scheme_name]
responses = operation / 'responses'
responses_spec = operation_spec.get('responses') responses_spec = operation_spec.get('responses')
for http_status, response in iteritems(responses):
for http_status, response in iteritems(operation.responses):
assert type(response) == Response
assert response.http_status == http_status
response_spec = responses_spec[http_status] response_spec = responses_spec[http_status]
if not response_spec: if not response_spec:
@ -179,17 +164,16 @@ class TestPetstore(object):
description_spec = response_spec['description'] description_spec = response_spec['description']
assert response.description == description_spec assert response.getkey('description') == description_spec
for parameter_name, parameter in iteritems(
response.headers):
assert type(parameter) == Parameter
assert parameter.name == parameter_name
headers = response.get('headers', {})
for parameter_name, parameter in iteritems(headers):
headers_spec = response_spec['headers'] headers_spec = response_spec['headers']
parameter_spec = headers_spec[parameter_name] parameter_spec = headers_spec[parameter_name]
schema = parameter.get('schema')
schema_spec = parameter_spec.get('schema') schema_spec = parameter_spec.get('schema')
assert bool(schema_spec) == bool(parameter.schema) assert bool(schema_spec) == bool(schema)
if not schema_spec: if not schema_spec:
continue continue
@ -198,13 +182,12 @@ class TestPetstore(object):
if '$ref' in schema_spec: if '$ref' in schema_spec:
continue continue
assert type(parameter.schema) == Schema assert schema['type'] ==\
assert parameter.schema.type.value ==\
schema_spec['type'] schema_spec['type']
assert parameter.schema.format ==\ assert schema.getkey('format') ==\
schema_spec.get('format') schema_spec.get('format')
assert parameter.schema.required == schema_spec.get( assert schema.getkey('required') == schema_spec.get(
'required', []) 'required')
content_spec = parameter_spec.get('content') content_spec = parameter_spec.get('content')
assert bool(content_spec) == bool(parameter.content) assert bool(content_spec) == bool(parameter.content)
@ -212,14 +195,12 @@ class TestPetstore(object):
if not content_spec: if not content_spec:
continue continue
for mimetype, media_type in iteritems( content = parameter.get('content', {})
parameter.content): for mimetype, media_type in iteritems(content):
assert type(media_type) == MediaType
assert media_type.mimetype == mimetype
media_spec = parameter_spec['content'][mimetype] media_spec = parameter_spec['content'][mimetype]
schema = media_type.get('schema')
schema_spec = media_spec.get('schema') schema_spec = media_spec.get('schema')
assert bool(schema_spec) == bool(media_type.schema) assert bool(schema_spec) == bool(schema)
if not schema_spec: if not schema_spec:
continue continue
@ -228,30 +209,28 @@ class TestPetstore(object):
if '$ref' in schema_spec: if '$ref' in schema_spec:
continue continue
assert type(media_type.schema) == Schema assert schema['type'] ==\
assert media_type.schema.type.value ==\
schema_spec['type'] schema_spec['type']
assert media_type.schema.format ==\ assert schema.getkey('format') ==\
schema_spec.get('format') schema_spec.get('format')
assert media_type.schema.required == \ assert schema.getkey('required') == \
schema_spec.get('required', False) schema_spec.get('required')
content_spec = response_spec.get('content') content_spec = response_spec.get('content')
if not content_spec: if not content_spec:
continue continue
for mimetype, media_type in iteritems(response.content): content = response.get('content', {})
assert type(media_type) == MediaType for mimetype, media_type in iteritems(content):
assert media_type.mimetype == mimetype
content_spec = response_spec['content'][mimetype] content_spec = response_spec['content'][mimetype]
example_spec = content_spec.get('example') example_spec = content_spec.get('example')
assert media_type.example == example_spec assert media_type.getkey('example') == example_spec
schema = media_type.get('schema')
schema_spec = content_spec.get('schema') schema_spec = content_spec.get('schema')
assert bool(schema_spec) == bool(media_type.schema) assert bool(schema_spec) == bool(schema)
if not schema_spec: if not schema_spec:
continue continue
@ -260,31 +239,24 @@ class TestPetstore(object):
if '$ref' in schema_spec: if '$ref' in schema_spec:
continue continue
assert type(media_type.schema) == Schema assert schema['type'] == schema_spec['type']
assert media_type.schema.type.value ==\ assert schema.getkey('required') == schema_spec.get(
schema_spec['type'] 'required')
assert media_type.schema.required == schema_spec.get(
'required', [])
request_body = operation.get('requestBody')
request_body_spec = operation_spec.get('requestBody') request_body_spec = operation_spec.get('requestBody')
assert bool(request_body_spec) == bool(request_body)
assert bool(request_body_spec) == bool(operation.request_body)
if not request_body_spec: if not request_body_spec:
continue continue
assert type(operation.request_body) == RequestBody assert bool(request_body.getkey('required')) ==\
assert bool(operation.request_body.required) ==\ request_body_spec.get('required')
request_body_spec.get('required', False)
for mimetype, media_type in iteritems(
operation.request_body.content):
assert type(media_type) == MediaType
assert media_type.mimetype == mimetype
content = request_body / 'content'
for mimetype, media_type in iteritems(content):
content_spec = request_body_spec['content'][mimetype] content_spec = request_body_spec['content'][mimetype]
schema_spec = content_spec.get('schema') schema_spec = content_spec.get('schema')
assert bool(schema_spec) == bool(media_type.schema)
if not schema_spec: if not schema_spec:
continue continue
@ -293,20 +265,22 @@ class TestPetstore(object):
if '$ref' in schema_spec: if '$ref' in schema_spec:
continue continue
assert type(media_type.schema) == Schema schema = content.get('schema')
assert media_type.schema.type.value ==\ assert bool(schema_spec) == bool(schema)
assert schema.type.value ==\
schema_spec['type'] schema_spec['type']
assert media_type.schema.format ==\ assert schema.format ==\
schema_spec.get('format') schema_spec.get('format')
assert media_type.schema.required == schema_spec.get( assert schema.required == schema_spec.get(
'required', False) 'required', False)
if not spec.components: components = spec.get('components')
if not components:
return return
for schema_name, schema in iteritems(spec.components.schemas): schemas = components.get('schemas', {})
assert type(schema) == Schema for schema_name, schema in iteritems(schemas):
schema_spec = spec_dict['components']['schemas'][schema_name] schema_spec = spec_dict['components']['schemas'][schema_name]
assert schema.read_only == schema_spec.get('readOnly', False) assert schema.getkey('readOnly') == schema_spec.get('readOnly')
assert schema.write_only == schema_spec.get('writeOnly', False) assert schema.getkey('writeOnly') == schema_spec.get('writeOnly')

View file

@ -15,7 +15,6 @@ from openapi_core.extensions.models.models import BaseModel
from openapi_core.schema.parameters.exceptions import ( from openapi_core.schema.parameters.exceptions import (
MissingRequiredParameter, MissingRequiredParameter,
) )
from openapi_core.schema.schemas.enums import SchemaType
from openapi_core.shortcuts import ( from openapi_core.shortcuts import (
create_spec, validate_parameters, validate_body, validate_data, create_spec, validate_parameters, validate_body, validate_data,
) )
@ -187,7 +186,7 @@ class TestPetstore(object):
schema_errors = response_result.errors[0].schema_errors schema_errors = response_result.errors[0].schema_errors
assert response_result.errors == [ assert response_result.errors == [
InvalidSchemaValue( InvalidSchemaValue(
type=SchemaType.OBJECT, type='object',
value=response_data_json, value=response_data_json,
schema_errors=schema_errors, schema_errors=schema_errors,
), ),

View file

@ -6,7 +6,6 @@ from openapi_core.deserializing.exceptions import DeserializeError
from openapi_core.deserializing.media_types.factories import ( from openapi_core.deserializing.media_types.factories import (
MediaTypeDeserializersFactory, MediaTypeDeserializersFactory,
) )
from openapi_core.schema.media_types.models import MediaType
class TestMediaTypeDeserializer(object): class TestMediaTypeDeserializer(object):
@ -19,46 +18,46 @@ class TestMediaTypeDeserializer(object):
return create_deserializer return create_deserializer
def test_json_empty(self, deserializer_factory): def test_json_empty(self, deserializer_factory):
media_type = MediaType('application/json') mimetype = 'application/json'
value = '' value = ''
with pytest.raises(DeserializeError): with pytest.raises(DeserializeError):
deserializer_factory(media_type)(value) deserializer_factory(mimetype)(value)
def test_json_empty_object(self, deserializer_factory): def test_json_empty_object(self, deserializer_factory):
media_type = MediaType('application/json') mimetype = 'application/json'
value = "{}" value = "{}"
result = deserializer_factory(media_type)(value) result = deserializer_factory(mimetype)(value)
assert result == {} assert result == {}
def test_urlencoded_form_empty(self, deserializer_factory): def test_urlencoded_form_empty(self, deserializer_factory):
media_type = MediaType('application/x-www-form-urlencoded') mimetype = 'application/x-www-form-urlencoded'
value = '' value = ''
result = deserializer_factory(media_type)(value) result = deserializer_factory(mimetype)(value)
assert result == {} assert result == {}
def test_urlencoded_form_simple(self, deserializer_factory): def test_urlencoded_form_simple(self, deserializer_factory):
media_type = MediaType('application/x-www-form-urlencoded') mimetype = 'application/x-www-form-urlencoded'
value = 'param1=test' value = 'param1=test'
result = deserializer_factory(media_type)(value) result = deserializer_factory(mimetype)(value)
assert result == {'param1': 'test'} assert result == {'param1': 'test'}
@pytest.mark.parametrize('value', [b(''), u('')]) @pytest.mark.parametrize('value', [b(''), u('')])
def test_data_form_empty(self, deserializer_factory, value): def test_data_form_empty(self, deserializer_factory, value):
media_type = MediaType('multipart/form-data') mimetype = 'multipart/form-data'
result = deserializer_factory(media_type)(value) result = deserializer_factory(mimetype)(value)
assert result == {} assert result == {}
def test_data_form_simple(self, deserializer_factory): def test_data_form_simple(self, deserializer_factory):
media_type = MediaType('multipart/form-data') mimetype = 'multipart/form-data'
value = b( value = b(
'Content-Type: multipart/form-data; boundary="' 'Content-Type: multipart/form-data; boundary="'
'===============2872712225071193122=="\n' '===============2872712225071193122=="\n'
@ -69,13 +68,12 @@ class TestMediaTypeDeserializer(object):
'--===============2872712225071193122==--\n' '--===============2872712225071193122==--\n'
) )
result = deserializer_factory(media_type)(value) result = deserializer_factory(mimetype)(value)
assert result == {'param1': b('test')} assert result == {'param1': b('test')}
def test_custom_simple(self, deserializer_factory): def test_custom_simple(self, deserializer_factory):
custom_mimetype = 'application/custom' custom_mimetype = 'application/custom'
media_type = MediaType(custom_mimetype)
value = "{}" value = "{}"
def custom_deserializer(value): def custom_deserializer(value):
@ -85,6 +83,6 @@ class TestMediaTypeDeserializer(object):
} }
result = deserializer_factory( result = deserializer_factory(
media_type, custom_deserializers=custom_deserializers)(value) custom_mimetype, custom_deserializers=custom_deserializers)(value)
assert result == 'custom' assert result == 'custom'

View file

@ -6,7 +6,7 @@ from openapi_core.deserializing.parameters.factories import (
from openapi_core.deserializing.parameters.exceptions import ( from openapi_core.deserializing.parameters.exceptions import (
EmptyParameterValue, EmptyParameterValue,
) )
from openapi_core.schema.parameters.models import Parameter from openapi_core.spec.paths import SpecPath
class TestParameterDeserializer(object): class TestParameterDeserializer(object):
@ -18,7 +18,12 @@ class TestParameterDeserializer(object):
return create_deserializer return create_deserializer
def test_deprecated(self, deserializer_factory): def test_deprecated(self, deserializer_factory):
param = Parameter('param', 'query', deprecated=True) spec = {
'name': 'param',
'in': 'query',
'deprecated': True,
}
param = SpecPath.from_spec(spec)
value = 'test' value = 'test'
with pytest.warns(DeprecationWarning): with pytest.warns(DeprecationWarning):
@ -27,14 +32,22 @@ class TestParameterDeserializer(object):
assert result == value assert result == value
def test_query_empty(self, deserializer_factory): def test_query_empty(self, deserializer_factory):
param = Parameter('param', 'query') spec = {
'name': 'param',
'in': 'query',
}
param = SpecPath.from_spec(spec)
value = '' value = ''
with pytest.raises(EmptyParameterValue): with pytest.raises(EmptyParameterValue):
deserializer_factory(param)(value) deserializer_factory(param)(value)
def test_query_valid(self, deserializer_factory): def test_query_valid(self, deserializer_factory):
param = Parameter('param', 'query') spec = {
'name': 'param',
'in': 'query',
}
param = SpecPath.from_spec(spec)
value = 'test' value = 'test'
result = deserializer_factory(param)(value) result = deserializer_factory(param)(value)

View file

@ -1,15 +1,22 @@
import pytest import pytest
from openapi_core.schema.security_schemes.models import SecurityScheme
from openapi_core.security.providers import HttpProvider from openapi_core.security.providers import HttpProvider
from openapi_core.spec.paths import SpecPath
from openapi_core.testing import MockRequest from openapi_core.testing import MockRequest
class TestHttpProvider(object): class TestHttpProvider(object):
@pytest.fixture @pytest.fixture
def scheme(self): def spec(self):
return SecurityScheme('http', scheme='bearer') return {
'type': 'http',
'scheme': 'bearer',
}
@pytest.fixture
def scheme(self, spec):
return SpecPath.from_spec(spec)
@pytest.fixture @pytest.fixture
def provider(self, scheme): def provider(self, scheme):

View file

@ -1,11 +1,6 @@
import pytest import pytest
from openapi_core.schema.infos.models import Info from openapi_core.spec.paths import SpecPath
from openapi_core.schema.operations.models import Operation
from openapi_core.schema.parameters.models import Parameter
from openapi_core.schema.paths.models import Path
from openapi_core.schema.servers.models import Server, ServerVariable
from openapi_core.schema.specs.models import Spec
from openapi_core.templating.datatypes import TemplateResult from openapi_core.templating.datatypes import TemplateResult
from openapi_core.templating.paths.exceptions import ( from openapi_core.templating.paths.exceptions import (
PathNotFound, OperationNotFound, ServerNotFound, PathNotFound, OperationNotFound, ServerNotFound,
@ -19,8 +14,25 @@ class BaseTestSimpleServer(object):
server_url = 'http://petstore.swagger.io' server_url = 'http://petstore.swagger.io'
@pytest.fixture @pytest.fixture
def server(self): def server_variable(self):
return Server(self.server_url, {}) return {}
@pytest.fixture
def server_variables(self, server_variable):
if not server_variable:
return {}
return {
self.server_variable_name: server_variable,
}
@pytest.fixture
def server(self, server_variables):
server = {
'url': self.server_url,
}
if server_variables:
server['variables'] = server_variables
return server
@pytest.fixture @pytest.fixture
def servers(self, server): def servers(self, server):
@ -36,22 +48,13 @@ class BaseTestVariableServer(BaseTestSimpleServer):
@pytest.fixture @pytest.fixture
def server_variable(self): def server_variable(self):
return ServerVariable(
self.server_variable_name,
default=self.server_variable_default,
enum=self.server_variable_enum,
)
@pytest.fixture
def server_variables(self, server_variable):
return { return {
self.server_variable_name: server_variable, self.server_variable_name: {
'default': self.server_variable_default,
'enum': self.server_variable_enum,
}
} }
@pytest.fixture
def server(self, server_variables):
return Server(self.server_url, server_variables)
class BaseTestSimplePath(object): class BaseTestSimplePath(object):
@ -59,7 +62,7 @@ class BaseTestSimplePath(object):
@pytest.fixture @pytest.fixture
def path(self, operations): def path(self, operations):
return Path(self.path_name, operations) return operations
@pytest.fixture @pytest.fixture
def paths(self, path): def paths(self, path):
@ -75,28 +78,38 @@ class BaseTestVariablePath(BaseTestSimplePath):
@pytest.fixture @pytest.fixture
def parameter(self): def parameter(self):
return Parameter(self.path_parameter_name, 'path')
@pytest.fixture
def parameters(self, parameter):
return { return {
self.path_parameter_name: parameter 'name': self.path_parameter_name,
'in': 'path',
} }
@pytest.fixture
def parameters(self, parameter):
return [parameter, ]
@pytest.fixture @pytest.fixture
def path(self, operations, parameters): def path(self, operations, parameters):
return Path(self.path_name, operations, parameters=parameters) path = operations.copy()
path['parameters'] = parameters
return path
class BaseTestSpecServer(object): class BaseTestSpecServer(object):
location = 'spec'
@pytest.fixture @pytest.fixture
def info(self): def info(self):
return Info('Test schema', '1.0') return {
'title': 'Test schema',
'version': '1.0',
}
@pytest.fixture @pytest.fixture
def operation(self): def operation(self):
return Operation('get', self.path_name, {}, {}) return {
'responses': [],
}
@pytest.fixture @pytest.fixture
def operations(self, operation): def operations(self, operation):
@ -106,7 +119,12 @@ class BaseTestSpecServer(object):
@pytest.fixture @pytest.fixture
def spec(self, info, paths, servers): def spec(self, info, paths, servers):
return Spec(info, paths, servers) spec = {
'info': info,
'servers': servers,
'paths': paths,
}
return SpecPath.from_spec(spec)
@pytest.fixture @pytest.fixture
def finder(self, spec): def finder(self, spec):
@ -115,24 +133,41 @@ class BaseTestSpecServer(object):
class BaseTestPathServer(BaseTestSpecServer): class BaseTestPathServer(BaseTestSpecServer):
location = 'path'
@pytest.fixture @pytest.fixture
def path(self, operations, servers): def path(self, operations, servers):
return Path(self.path_name, operations, servers=servers) path = operations.copy()
path['servers'] = servers
return path
@pytest.fixture @pytest.fixture
def spec(self, info, paths): def spec(self, info, paths):
return Spec(info, paths) spec = {
'info': info,
'paths': paths,
}
return SpecPath.from_spec(spec)
class BaseTestOperationServer(BaseTestSpecServer): class BaseTestOperationServer(BaseTestSpecServer):
location = 'operation'
@pytest.fixture @pytest.fixture
def operation(self, servers): def operation(self, servers):
return Operation('get', self.path_name, {}, {}, servers=servers) return {
'responses': [],
'servers': servers,
}
@pytest.fixture @pytest.fixture
def spec(self, info, paths): def spec(self, info, paths):
return Spec(info, paths) spec = {
'info': info,
'paths': paths,
}
return SpecPath.from_spec(spec)
class BaseTestServerNotFound(object): class BaseTestServerNotFound(object):
@ -141,6 +176,7 @@ class BaseTestServerNotFound(object):
def servers(self): def servers(self):
return [] return []
@pytest.mark.xfail(reason="returns default server")
def test_raises(self, finder): def test_raises(self, finder):
request_uri = '/resource' request_uri = '/resource'
request = MockRequest( request = MockRequest(
@ -167,13 +203,17 @@ class BaseTestOperationNotFound(object):
class BaseTestValid(object): class BaseTestValid(object):
def test_simple(self, finder, path, operation, server): def test_simple(self, finder, spec):
request_uri = '/resource' request_uri = '/resource'
method = 'get'
request = MockRequest( request = MockRequest(
'http://petstore.swagger.io', 'get', request_uri) 'http://petstore.swagger.io', method, request_uri)
result = finder.find(request) result = finder.find(request)
path = spec / 'paths' / self.path_name
operation = spec / 'paths' / self.path_name / method
server = eval(self.location) / 'servers' / 0
path_result = TemplateResult(self.path_name, {}) path_result = TemplateResult(self.path_name, {})
server_result = TemplateResult(self.server_url, {}) server_result = TemplateResult(self.server_url, {})
assert result == ( assert result == (
@ -184,13 +224,17 @@ class BaseTestValid(object):
class BaseTestVariableValid(object): class BaseTestVariableValid(object):
@pytest.mark.parametrize('version', ['v1', 'v2']) @pytest.mark.parametrize('version', ['v1', 'v2'])
def test_variable(self, finder, path, operation, server, version): def test_variable(self, finder, spec, version):
request_uri = '/{0}/resource'.format(version) request_uri = '/{0}/resource'.format(version)
method = 'get'
request = MockRequest( request = MockRequest(
'http://petstore.swagger.io', 'get', request_uri) 'http://petstore.swagger.io', method, request_uri)
result = finder.find(request) result = finder.find(request)
path = spec / 'paths' / self.path_name
operation = spec / 'paths' / self.path_name / method
server = eval(self.location) / 'servers' / 0
path_result = TemplateResult(self.path_name, {}) path_result = TemplateResult(self.path_name, {})
server_result = TemplateResult(self.server_url, {'version': version}) server_result = TemplateResult(self.server_url, {'version': version})
assert result == ( assert result == (
@ -201,13 +245,17 @@ class BaseTestVariableValid(object):
class BaseTestPathVariableValid(object): class BaseTestPathVariableValid(object):
@pytest.mark.parametrize('res_id', ['111', '222']) @pytest.mark.parametrize('res_id', ['111', '222'])
def test_path_variable(self, finder, path, operation, server, res_id): def test_path_variable(self, finder, spec, res_id):
request_uri = '/resource/{0}'.format(res_id) request_uri = '/resource/{0}'.format(res_id)
method = 'get'
request = MockRequest( request = MockRequest(
'http://petstore.swagger.io', 'get', request_uri) 'http://petstore.swagger.io', method, request_uri)
result = finder.find(request) result = finder.find(request)
path = spec / 'paths' / self.path_name
operation = spec / 'paths' / self.path_name / method
server = eval(self.location) / 'servers' / 0
path_result = TemplateResult(self.path_name, {'resource_id': res_id}) path_result = TemplateResult(self.path_name, {'resource_id': res_id})
server_result = TemplateResult(self.server_url, {}) server_result = TemplateResult(self.server_url, {})
assert result == ( assert result == (
@ -396,10 +444,13 @@ class TestSimilarPaths(
BaseTestSpecServer, BaseTestSimpleServer): BaseTestSpecServer, BaseTestSimpleServer):
path_name = '/tokens' path_name = '/tokens'
path_2_name = '/keys/{id}/tokens'
@pytest.fixture @pytest.fixture
def operation_2(self): def operation_2(self):
return Operation('get', '/keys/{id}/tokens', {}, {}) return {
'responses': [],
}
@pytest.fixture @pytest.fixture
def operations_2(self, operation_2): def operations_2(self, operation_2):
@ -409,28 +460,32 @@ class TestSimilarPaths(
@pytest.fixture @pytest.fixture
def path(self, operations): def path(self, operations):
return Path('/tokens', operations) return operations
@pytest.fixture @pytest.fixture
def path_2(self, operations_2): def path_2(self, operations_2):
return Path('/keys/{id}/tokens', operations_2) return operations_2
@pytest.fixture @pytest.fixture
def paths(self, path, path_2): def paths(self, path, path_2):
return { return {
path.name: path, self.path_name: path,
path_2.name: path_2, self.path_2_name: path_2,
} }
def test_valid(self, finder, path_2, operation_2, server): def test_valid(self, finder, spec):
token_id = '123' token_id = '123'
request_uri = '/keys/{0}/tokens'.format(token_id) request_uri = '/keys/{0}/tokens'.format(token_id)
method = 'get'
request = MockRequest( request = MockRequest(
'http://petstore.swagger.io', 'get', request_uri) 'http://petstore.swagger.io', method, request_uri)
result = finder.find(request) result = finder.find(request)
path_result = TemplateResult(path_2.name, {'id': token_id}) path_2 = spec / 'paths' / self.path_2_name
operation_2 = spec / 'paths' / self.path_2_name / method
server = eval(self.location) / 'servers' / 0
path_result = TemplateResult(self.path_2_name, {'id': token_id})
server_result = TemplateResult(self.server_url, {}) server_result = TemplateResult(self.server_url, {})
assert result == ( assert result == (
path_2, operation_2, server, path_result, server_result, path_2, operation_2, server, path_result, server_result,
@ -441,10 +496,13 @@ class TestConcretePaths(
BaseTestSpecServer, BaseTestSimpleServer): BaseTestSpecServer, BaseTestSimpleServer):
path_name = '/keys/{id}/tokens' path_name = '/keys/{id}/tokens'
path_2_name = '/keys/master/tokens'
@pytest.fixture @pytest.fixture
def operation_2(self): def operation_2(self):
return Operation('get', '/keys/master/tokens', {}, {}) return {
'responses': [],
}
@pytest.fixture @pytest.fixture
def operations_2(self, operation_2): def operations_2(self, operation_2):
@ -454,26 +512,30 @@ class TestConcretePaths(
@pytest.fixture @pytest.fixture
def path(self, operations): def path(self, operations):
return Path('/keys/{id}/tokens', operations) return operations
@pytest.fixture @pytest.fixture
def path_2(self, operations_2): def path_2(self, operations_2):
return Path('/keys/master/tokens', operations_2) return operations_2
@pytest.fixture @pytest.fixture
def paths(self, path, path_2): def paths(self, path, path_2):
return { return {
path.name: path, self.path_name: path,
path_2.name: path_2, self.path_2_name: path_2,
} }
def test_valid(self, finder, path_2, operation_2, server): def test_valid(self, finder, spec):
request_uri = '/keys/master/tokens' request_uri = '/keys/master/tokens'
method = 'get'
request = MockRequest( request = MockRequest(
'http://petstore.swagger.io', 'get', request_uri) 'http://petstore.swagger.io', method, request_uri)
result = finder.find(request) result = finder.find(request)
path_result = TemplateResult(path_2.name, {}) path_2 = spec / 'paths' / self.path_2_name
operation_2 = spec / 'paths' / self.path_2_name / method
server = eval(self.location) / 'servers' / 0
path_result = TemplateResult(self.path_2_name, {})
server_result = TemplateResult(self.server_url, {}) server_result = TemplateResult(self.server_url, {})
assert result == ( assert result == (
path_2, operation_2, server, path_result, server_result, path_2, operation_2, server, path_result, server_result,
@ -484,10 +546,13 @@ class TestTemplateConcretePaths(
BaseTestSpecServer, BaseTestSimpleServer): BaseTestSpecServer, BaseTestSimpleServer):
path_name = '/keys/{id}/tokens/{id2}' path_name = '/keys/{id}/tokens/{id2}'
path_2_name = '/keys/{id}/tokens/master'
@pytest.fixture @pytest.fixture
def operation_2(self): def operation_2(self):
return Operation('get', '/keys/{id}/tokens/master', {}, {}) return {
'responses': [],
}
@pytest.fixture @pytest.fixture
def operations_2(self, operation_2): def operations_2(self, operation_2):
@ -497,27 +562,31 @@ class TestTemplateConcretePaths(
@pytest.fixture @pytest.fixture
def path(self, operations): def path(self, operations):
return Path('/keys/{id}/tokens/{id2}', operations) return operations
@pytest.fixture @pytest.fixture
def path_2(self, operations_2): def path_2(self, operations_2):
return Path('/keys/{id}/tokens/master', operations_2) return operations_2
@pytest.fixture @pytest.fixture
def paths(self, path, path_2): def paths(self, path, path_2):
return { return {
path.name: path, self.path_name: path,
path_2.name: path_2, self.path_2_name: path_2,
} }
def test_valid(self, finder, path_2, operation_2, server): def test_valid(self, finder, spec):
token_id = '123' token_id = '123'
request_uri = '/keys/{0}/tokens/master'.format(token_id) request_uri = '/keys/{0}/tokens/master'.format(token_id)
method = 'get'
request = MockRequest( request = MockRequest(
'http://petstore.swagger.io', 'get', request_uri) 'http://petstore.swagger.io', method, request_uri)
result = finder.find(request) result = finder.find(request)
path_result = TemplateResult(path_2.name, {'id': '123'}) path_2 = spec / 'paths' / self.path_2_name
operation_2 = spec / 'paths' / self.path_2_name / method
server = eval(self.location) / 'servers' / 0
path_result = TemplateResult(self.path_2_name, {'id': '123'})
server_result = TemplateResult(self.server_url, {}) server_result = TemplateResult(self.server_url, {})
assert result == ( assert result == (
path_2, operation_2, server, path_result, server_result, path_2, operation_2, server, path_result, server_result,

View file

@ -1,13 +1,14 @@
import mock import mock
import pytest import pytest
from openapi_core.spec.paths import SpecPath
from openapi_core.templating.responses.finders import ResponseFinder from openapi_core.templating.responses.finders import ResponseFinder
class TestResponses(object): class TestResponses(object):
@pytest.fixture(scope='class') @pytest.fixture(scope='class')
def responses(self): def spec(self):
return { return {
'200': mock.sentinel.response_200, '200': mock.sentinel.response_200,
'299': mock.sentinel.response_299, '299': mock.sentinel.response_299,
@ -15,6 +16,10 @@ class TestResponses(object):
'default': mock.sentinel.response_default, 'default': mock.sentinel.response_default,
} }
@pytest.fixture(scope='class')
def responses(self, spec):
return SpecPath.from_spec(spec)
@pytest.fixture(scope='class') @pytest.fixture(scope='class')
def finder(self, responses): def finder(self, responses):
return ResponseFinder(responses) return ResponseFinder(responses)
@ -22,14 +27,14 @@ class TestResponses(object):
def test_default(self, finder, responses): def test_default(self, finder, responses):
response = finder.find() response = finder.find()
assert response == responses['default'] assert response == responses / 'default'
def test_range(self, finder, responses): def test_range(self, finder, responses):
response = finder.find('201') response = finder.find('201')
assert response == responses['2XX'] assert response == responses / '2XX'
def test_exact(self, finder, responses): def test_exact(self, finder, responses):
response = finder.find('200') response = finder.find('200')
assert response == responses['200'] assert response == responses / '200'

View file

@ -4,11 +4,8 @@ import uuid
from isodate.tzinfo import UTC, FixedOffset from isodate.tzinfo import UTC, FixedOffset
import pytest import pytest
from openapi_core.schema.media_types.models import MediaType
from openapi_core.schema.parameters.models import Parameter
from openapi_core.schema.schemas.enums import SchemaType
from openapi_core.schema.schemas.models import Schema
from openapi_core.schema.schemas.types import NoValue from openapi_core.schema.schemas.types import NoValue
from openapi_core.spec.paths import SpecPath
from openapi_core.unmarshalling.schemas.enums import UnmarshalContext from openapi_core.unmarshalling.schemas.enums import UnmarshalContext
from openapi_core.unmarshalling.schemas.exceptions import ( from openapi_core.unmarshalling.schemas.exceptions import (
InvalidSchemaFormatValue, InvalidSchemaValue, UnmarshalError, InvalidSchemaFormatValue, InvalidSchemaValue, UnmarshalError,
@ -33,22 +30,24 @@ def unmarshaller_factory():
return create_unmarshaller return create_unmarshaller
class TestParameterUnmarshal(object): class TestUnmarshal(object):
def test_no_schema(self, unmarshaller_factory): def test_no_schema(self, unmarshaller_factory):
param = Parameter('param', 'query') schema = None
value = 'test' value = 'test'
with pytest.raises(TypeError): with pytest.raises(TypeError):
unmarshaller_factory(param.schema).unmarshal(value) unmarshaller_factory(schema).unmarshal(value)
def test_schema_type_invalid(self, unmarshaller_factory): def test_schema_type_invalid(self, unmarshaller_factory):
schema = Schema('integer', _source={'type': 'integer'}) spec = {
param = Parameter('param', 'query', schema=schema) 'type': 'integer',
}
schema = SpecPath.from_spec(spec)
value = 'test' value = 'test'
with pytest.raises(InvalidSchemaFormatValue): with pytest.raises(InvalidSchemaFormatValue):
unmarshaller_factory(param.schema).unmarshal(value) unmarshaller_factory(schema).unmarshal(value)
def test_schema_custom_format_invalid(self, unmarshaller_factory): def test_schema_custom_format_invalid(self, unmarshaller_factory):
@ -60,59 +59,16 @@ class TestParameterUnmarshal(object):
custom_formatters = { custom_formatters = {
custom_format: formatter, custom_format: formatter,
} }
schema = Schema( spec = {
'string', 'type': 'string',
schema_format=custom_format, 'format': 'custom',
_source={'type': 'string', 'format': 'custom'},
)
param = Parameter('param', 'query', schema=schema)
value = 'test'
with pytest.raises(InvalidSchemaFormatValue):
unmarshaller_factory(
param.schema,
custom_formatters=custom_formatters,
).unmarshal(value)
class TestMediaTypeUnmarshal(object):
def test_no_schema(self, unmarshaller_factory):
media_type = MediaType('application/json')
value = 'test'
with pytest.raises(TypeError):
unmarshaller_factory(media_type.schema).unmarshal(value)
def test_schema_type_invalid(self, unmarshaller_factory):
schema = Schema('integer', _source={'type': 'integer'})
media_type = MediaType('application/json', schema=schema)
value = 'test'
with pytest.raises(InvalidSchemaFormatValue):
unmarshaller_factory(media_type.schema).unmarshal(value)
def test_schema_custom_format_invalid(self, unmarshaller_factory):
class CustomFormatter(Formatter):
def unmarshal(self, value):
raise ValueError
formatter = CustomFormatter()
custom_format = 'custom'
custom_formatters = {
custom_format: formatter,
} }
schema = Schema( schema = SpecPath.from_spec(spec)
'string',
schema_format=custom_format,
_source={'type': 'string', 'format': 'custom'},
)
media_type = MediaType('application/json', schema=schema)
value = 'test' value = 'test'
with pytest.raises(InvalidSchemaFormatValue): with pytest.raises(InvalidSchemaFormatValue):
unmarshaller_factory( unmarshaller_factory(
media_type.schema, schema,
custom_formatters=custom_formatters, custom_formatters=custom_formatters,
).unmarshal(value) ).unmarshal(value)
@ -120,7 +76,11 @@ class TestMediaTypeUnmarshal(object):
class TestSchemaUnmarshallerCall(object): class TestSchemaUnmarshallerCall(object):
def test_deprecated(self, unmarshaller_factory): def test_deprecated(self, unmarshaller_factory):
schema = Schema('string', deprecated=True) spec = {
'type': 'string',
'deprecated': True,
}
schema = SpecPath.from_spec(spec)
value = 'test' value = 'test'
with pytest.warns(DeprecationWarning): with pytest.warns(DeprecationWarning):
@ -132,14 +92,20 @@ class TestSchemaUnmarshallerCall(object):
'boolean', 'array', 'integer', 'number', 'boolean', 'array', 'integer', 'number',
]) ])
def test_non_string_empty_value(self, schema_type, unmarshaller_factory): def test_non_string_empty_value(self, schema_type, unmarshaller_factory):
schema = Schema(schema_type) spec = {
'type': schema_type,
}
schema = SpecPath.from_spec(spec)
value = '' value = ''
with pytest.raises(InvalidSchemaValue): with pytest.raises(InvalidSchemaValue):
unmarshaller_factory(schema)(value) unmarshaller_factory(schema)(value)
def test_string_valid(self, unmarshaller_factory): def test_string_valid(self, unmarshaller_factory):
schema = Schema('string') spec = {
'type': 'string',
}
schema = SpecPath.from_spec(spec)
value = 'test' value = 'test'
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
@ -147,7 +113,11 @@ class TestSchemaUnmarshallerCall(object):
assert result == value assert result == value
def test_string_format_uuid_valid(self, unmarshaller_factory): def test_string_format_uuid_valid(self, unmarshaller_factory):
schema = Schema(SchemaType.STRING, schema_format='uuid') spec = {
'type': 'string',
'format': 'uuid',
}
schema = SpecPath.from_spec(spec)
value = str(uuid.uuid4()) value = str(uuid.uuid4())
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
@ -156,14 +126,22 @@ class TestSchemaUnmarshallerCall(object):
def test_string_format_uuid_uuid_quirks_invalid( def test_string_format_uuid_uuid_quirks_invalid(
self, unmarshaller_factory): self, unmarshaller_factory):
schema = Schema(SchemaType.STRING, schema_format='uuid') spec = {
'type': 'string',
'format': 'uuid',
}
schema = SpecPath.from_spec(spec)
value = uuid.uuid4() value = uuid.uuid4()
with pytest.raises(InvalidSchemaValue): with pytest.raises(InvalidSchemaValue):
unmarshaller_factory(schema)(value) unmarshaller_factory(schema)(value)
def test_string_format_password(self, unmarshaller_factory): def test_string_format_password(self, unmarshaller_factory):
schema = Schema(SchemaType.STRING, schema_format='password') spec = {
'type': 'string',
'format': 'password',
}
schema = SpecPath.from_spec(spec)
value = 'password' value = 'password'
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
@ -171,7 +149,10 @@ class TestSchemaUnmarshallerCall(object):
assert result == 'password' assert result == 'password'
def test_string_float_invalid(self, unmarshaller_factory): def test_string_float_invalid(self, unmarshaller_factory):
schema = Schema('string') spec = {
'type': 'string',
}
schema = SpecPath.from_spec(spec)
value = 1.23 value = 1.23
with pytest.raises(InvalidSchemaValue): with pytest.raises(InvalidSchemaValue):
@ -179,7 +160,11 @@ class TestSchemaUnmarshallerCall(object):
def test_string_default(self, unmarshaller_factory): def test_string_default(self, unmarshaller_factory):
default_value = 'default' default_value = 'default'
schema = Schema('string', default=default_value) spec = {
'type': 'string',
'default': default_value,
}
schema = SpecPath.from_spec(spec)
value = NoValue value = NoValue
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
@ -189,7 +174,12 @@ class TestSchemaUnmarshallerCall(object):
@pytest.mark.parametrize('default_value', ['default', None]) @pytest.mark.parametrize('default_value', ['default', None])
def test_string_default_nullable( def test_string_default_nullable(
self, default_value, unmarshaller_factory): self, default_value, unmarshaller_factory):
schema = Schema('string', default=default_value, nullable=True) spec = {
'type': 'string',
'default': default_value,
'nullable': True,
}
schema = SpecPath.from_spec(spec)
value = NoValue value = NoValue
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
@ -197,7 +187,11 @@ class TestSchemaUnmarshallerCall(object):
assert result == default_value assert result == default_value
def test_string_format_date(self, unmarshaller_factory): def test_string_format_date(self, unmarshaller_factory):
schema = Schema('string', schema_format='date') spec = {
'type': 'string',
'format': 'date',
}
schema = SpecPath.from_spec(spec)
value = '2018-01-02' value = '2018-01-02'
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
@ -205,14 +199,22 @@ class TestSchemaUnmarshallerCall(object):
assert result == datetime.date(2018, 1, 2) assert result == datetime.date(2018, 1, 2)
def test_string_format_datetime_invalid(self, unmarshaller_factory): def test_string_format_datetime_invalid(self, unmarshaller_factory):
schema = Schema('string', schema_format='date-time') spec = {
'type': 'string',
'format': 'date-time',
}
schema = SpecPath.from_spec(spec)
value = '2018-01-02T00:00:00' value = '2018-01-02T00:00:00'
with pytest.raises(InvalidSchemaValue): with pytest.raises(InvalidSchemaValue):
unmarshaller_factory(schema)(value) unmarshaller_factory(schema)(value)
def test_string_format_datetime_utc(self, unmarshaller_factory): def test_string_format_datetime_utc(self, unmarshaller_factory):
schema = Schema('string', schema_format='date-time') spec = {
'type': 'string',
'format': 'date-time',
}
schema = SpecPath.from_spec(spec)
value = '2018-01-02T00:00:00Z' value = '2018-01-02T00:00:00Z'
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
@ -221,7 +223,11 @@ class TestSchemaUnmarshallerCall(object):
assert result == datetime.datetime(2018, 1, 2, 0, 0, tzinfo=tzinfo) assert result == datetime.datetime(2018, 1, 2, 0, 0, tzinfo=tzinfo)
def test_string_format_datetime_tz(self, unmarshaller_factory): def test_string_format_datetime_tz(self, unmarshaller_factory):
schema = Schema('string', schema_format='date-time') spec = {
'type': 'string',
'format': 'date-time',
}
schema = SpecPath.from_spec(spec)
value = '2020-04-01T12:00:00+02:00' value = '2020-04-01T12:00:00+02:00'
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
@ -236,7 +242,11 @@ class TestSchemaUnmarshallerCall(object):
def unmarshal(self, value): def unmarshal(self, value):
return formatted return formatted
custom_format = 'custom' custom_format = 'custom'
schema = Schema('string', schema_format=custom_format) spec = {
'type': 'string',
'format': custom_format,
}
schema = SpecPath.from_spec(spec)
value = 'x' value = 'x'
formatter = CustomFormatter() formatter = CustomFormatter()
custom_formatters = { custom_formatters = {
@ -254,7 +264,11 @@ class TestSchemaUnmarshallerCall(object):
def unmarshal(self, value): def unmarshal(self, value):
raise ValueError raise ValueError
custom_format = 'custom' custom_format = 'custom'
schema = Schema('string', schema_format=custom_format) spec = {
'type': 'string',
'format': custom_format,
}
schema = SpecPath.from_spec(spec)
value = 'x' value = 'x'
formatter = CustomFormatter() formatter = CustomFormatter()
custom_formatters = { custom_formatters = {
@ -267,7 +281,11 @@ class TestSchemaUnmarshallerCall(object):
def test_string_format_unknown(self, unmarshaller_factory): def test_string_format_unknown(self, unmarshaller_factory):
unknown_format = 'unknown' unknown_format = 'unknown'
schema = Schema('string', schema_format=unknown_format) spec = {
'type': 'string',
'format': unknown_format,
}
schema = SpecPath.from_spec(spec)
value = 'x' value = 'x'
with pytest.raises(FormatterNotFoundError): with pytest.raises(FormatterNotFoundError):
@ -275,7 +293,11 @@ class TestSchemaUnmarshallerCall(object):
def test_string_format_invalid_value(self, unmarshaller_factory): def test_string_format_invalid_value(self, unmarshaller_factory):
custom_format = 'custom' custom_format = 'custom'
schema = Schema('string', schema_format=custom_format) spec = {
'type': 'string',
'format': custom_format,
}
schema = SpecPath.from_spec(spec)
value = 'x' value = 'x'
with pytest.raises( with pytest.raises(
@ -287,7 +309,10 @@ class TestSchemaUnmarshallerCall(object):
unmarshaller_factory(schema)(value) unmarshaller_factory(schema)(value)
def test_integer_valid(self, unmarshaller_factory): def test_integer_valid(self, unmarshaller_factory):
schema = Schema('integer') spec = {
'type': 'integer',
}
schema = SpecPath.from_spec(spec)
value = 123 value = 123
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
@ -295,21 +320,32 @@ class TestSchemaUnmarshallerCall(object):
assert result == int(value) assert result == int(value)
def test_integer_string_invalid(self, unmarshaller_factory): def test_integer_string_invalid(self, unmarshaller_factory):
schema = Schema('integer') spec = {
'type': 'integer',
}
schema = SpecPath.from_spec(spec)
value = '123' value = '123'
with pytest.raises(InvalidSchemaValue): with pytest.raises(InvalidSchemaValue):
unmarshaller_factory(schema)(value) unmarshaller_factory(schema)(value)
def test_integer_enum_invalid(self, unmarshaller_factory): def test_integer_enum_invalid(self, unmarshaller_factory):
schema = Schema('integer', enum=[1, 2, 3]) spec = {
'type': 'integer',
'enum': [1, 2, 3],
}
schema = SpecPath.from_spec(spec)
value = '123' value = '123'
with pytest.raises(UnmarshalError): with pytest.raises(UnmarshalError):
unmarshaller_factory(schema)(value) unmarshaller_factory(schema)(value)
def test_integer_enum(self, unmarshaller_factory): def test_integer_enum(self, unmarshaller_factory):
schema = Schema('integer', enum=[1, 2, 3]) spec = {
'type': 'integer',
'enum': [1, 2, 3],
}
schema = SpecPath.from_spec(spec)
value = 2 value = 2
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
@ -317,7 +353,11 @@ class TestSchemaUnmarshallerCall(object):
assert result == int(value) assert result == int(value)
def test_integer_enum_string_invalid(self, unmarshaller_factory): def test_integer_enum_string_invalid(self, unmarshaller_factory):
schema = Schema('integer', enum=[1, 2, 3]) spec = {
'type': 'integer',
'enum': [1, 2, 3],
}
schema = SpecPath.from_spec(spec)
value = '2' value = '2'
with pytest.raises(UnmarshalError): with pytest.raises(UnmarshalError):
@ -325,7 +365,11 @@ class TestSchemaUnmarshallerCall(object):
def test_integer_default(self, unmarshaller_factory): def test_integer_default(self, unmarshaller_factory):
default_value = 123 default_value = 123
schema = Schema('integer', default=default_value) spec = {
'type': 'integer',
'default': default_value,
}
schema = SpecPath.from_spec(spec)
value = NoValue value = NoValue
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
@ -334,7 +378,12 @@ class TestSchemaUnmarshallerCall(object):
def test_integer_default_nullable(self, unmarshaller_factory): def test_integer_default_nullable(self, unmarshaller_factory):
default_value = 123 default_value = 123
schema = Schema('integer', default=default_value, nullable=True) spec = {
'type': 'integer',
'default': default_value,
'nullable': True,
}
schema = SpecPath.from_spec(spec)
value = None value = None
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
@ -342,14 +391,23 @@ class TestSchemaUnmarshallerCall(object):
assert result is None assert result is None
def test_integer_invalid(self, unmarshaller_factory): def test_integer_invalid(self, unmarshaller_factory):
schema = Schema('integer') spec = {
'type': 'integer',
}
schema = SpecPath.from_spec(spec)
value = 'abc' value = 'abc'
with pytest.raises(InvalidSchemaValue): with pytest.raises(InvalidSchemaValue):
unmarshaller_factory(schema)(value) unmarshaller_factory(schema)(value)
def test_array_valid(self, unmarshaller_factory): def test_array_valid(self, unmarshaller_factory):
schema = Schema('array', items=Schema('integer')) spec = {
'type': 'array',
'items': {
'type': 'integer',
}
}
schema = SpecPath.from_spec(spec)
value = [1, 2, 3] value = [1, 2, 3]
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
@ -357,42 +415,63 @@ class TestSchemaUnmarshallerCall(object):
assert result == value assert result == value
def test_array_null(self, unmarshaller_factory): def test_array_null(self, unmarshaller_factory):
schema = Schema( spec = {
'array', 'type': 'array',
items=Schema('integer'), 'items': {
) 'type': 'integer',
}
}
schema = SpecPath.from_spec(spec)
value = None value = None
with pytest.raises(TypeError): with pytest.raises(TypeError):
unmarshaller_factory(schema)(value) unmarshaller_factory(schema)(value)
def test_array_nullable(self, unmarshaller_factory): def test_array_nullable(self, unmarshaller_factory):
schema = Schema( spec = {
'array', 'type': 'array',
items=Schema('integer'), 'items': {
nullable=True, 'type': 'integer',
) },
'nullable': True,
}
schema = SpecPath.from_spec(spec)
value = None value = None
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
assert result is None assert result is None
def test_array_of_string_string_invalid(self, unmarshaller_factory): def test_array_of_string_string_invalid(self, unmarshaller_factory):
schema = Schema('array', items=Schema('string')) spec = {
'type': 'array',
'items': {
'type': 'string',
}
}
schema = SpecPath.from_spec(spec)
value = '123' value = '123'
with pytest.raises(InvalidSchemaValue): with pytest.raises(InvalidSchemaValue):
unmarshaller_factory(schema)(value) unmarshaller_factory(schema)(value)
def test_array_of_integer_string_invalid(self, unmarshaller_factory): def test_array_of_integer_string_invalid(self, unmarshaller_factory):
schema = Schema('array', items=Schema('integer')) spec = {
'type': 'array',
'items': {
'type': 'integer',
}
}
schema = SpecPath.from_spec(spec)
value = '123' value = '123'
with pytest.raises(InvalidSchemaValue): with pytest.raises(InvalidSchemaValue):
unmarshaller_factory(schema)(value) unmarshaller_factory(schema)(value)
def test_boolean_valid(self, unmarshaller_factory): def test_boolean_valid(self, unmarshaller_factory):
schema = Schema('boolean') spec = {
'type': 'boolean',
}
schema = SpecPath.from_spec(spec)
value = True value = True
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
@ -400,14 +479,20 @@ class TestSchemaUnmarshallerCall(object):
assert result == value assert result == value
def test_boolean_string_invalid(self, unmarshaller_factory): def test_boolean_string_invalid(self, unmarshaller_factory):
schema = Schema('boolean') spec = {
'type': 'boolean',
}
schema = SpecPath.from_spec(spec)
value = 'True' value = 'True'
with pytest.raises(InvalidSchemaValue): with pytest.raises(InvalidSchemaValue):
unmarshaller_factory(schema)(value) unmarshaller_factory(schema)(value)
def test_number_valid(self, unmarshaller_factory): def test_number_valid(self, unmarshaller_factory):
schema = Schema('number') spec = {
'type': 'number',
}
schema = SpecPath.from_spec(spec)
value = 1.23 value = 1.23
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
@ -415,14 +500,20 @@ class TestSchemaUnmarshallerCall(object):
assert result == value assert result == value
def test_number_string_invalid(self, unmarshaller_factory): def test_number_string_invalid(self, unmarshaller_factory):
schema = Schema('number') spec = {
'type': 'number',
}
schema = SpecPath.from_spec(spec)
value = '1.23' value = '1.23'
with pytest.raises(InvalidSchemaValue): with pytest.raises(InvalidSchemaValue):
unmarshaller_factory(schema)(value) unmarshaller_factory(schema)(value)
def test_number_int(self, unmarshaller_factory): def test_number_int(self, unmarshaller_factory):
schema = Schema('number') spec = {
'type': 'number',
}
schema = SpecPath.from_spec(spec)
value = 1 value = 1
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
@ -430,7 +521,10 @@ class TestSchemaUnmarshallerCall(object):
assert type(result) == int assert type(result) == int
def test_number_float(self, unmarshaller_factory): def test_number_float(self, unmarshaller_factory):
schema = Schema('number') spec = {
'type': 'number',
}
schema = SpecPath.from_spec(spec)
value = 1.2 value = 1.2
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
@ -438,42 +532,72 @@ class TestSchemaUnmarshallerCall(object):
assert type(result) == float assert type(result) == float
def test_number_format_float(self, unmarshaller_factory): def test_number_format_float(self, unmarshaller_factory):
schema = Schema('number', schema_format='float') spec = {
'type': 'number',
'format': 'float',
}
schema = SpecPath.from_spec(spec)
value = 1.2 value = 1.2
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
assert result == 1.2 assert result == 1.2
def test_number_format_double(self, unmarshaller_factory): def test_number_format_double(self, unmarshaller_factory):
schema = Schema('number', schema_format='double') spec = {
'type': 'number',
'format': 'double',
}
schema = SpecPath.from_spec(spec)
value = 1.2 value = 1.2
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
assert result == 1.2 assert result == 1.2
def test_object_nullable(self, unmarshaller_factory): def test_object_nullable(self, unmarshaller_factory):
schema = Schema( spec = {
'object', 'type': 'object',
properties={ 'properties': {
'foo': Schema('object', nullable=True), 'foo': {
'type': 'object',
'nullable': True,
}
}, },
) }
schema = SpecPath.from_spec(spec)
value = {'foo': None} value = {'foo': None}
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
assert result == {'foo': None} assert result == {'foo': None}
def test_schema_any_one_of(self, unmarshaller_factory): def test_schema_any_one_of(self, unmarshaller_factory):
schema = Schema(one_of=[ spec = {
Schema('string'), 'oneOf': [
Schema('array', items=Schema('string')), {
]) 'type': 'string',
},
{
'type': 'array',
'items': {
'type': 'string',
}
}
],
}
schema = SpecPath.from_spec(spec)
assert unmarshaller_factory(schema)(['hello']) == ['hello'] assert unmarshaller_factory(schema)(['hello']) == ['hello']
def test_schema_any_all_of(self, unmarshaller_factory): def test_schema_any_all_of(self, unmarshaller_factory):
schema = Schema(all_of=[ spec = {
Schema('array', items=Schema('string')), 'allOf': [
]) {
'type': 'array',
'items': {
'type': 'string',
}
}
],
}
schema = SpecPath.from_spec(spec)
assert unmarshaller_factory(schema)(['hello']) == ['hello'] assert unmarshaller_factory(schema)(['hello']) == ['hello']
@pytest.mark.parametrize('value', [ @pytest.mark.parametrize('value', [
@ -499,34 +623,45 @@ class TestSchemaUnmarshallerCall(object):
]) ])
def test_schema_any_all_of_invalid_properties( def test_schema_any_all_of_invalid_properties(
self, value, unmarshaller_factory): self, value, unmarshaller_factory):
schema = Schema( spec = {
all_of=[ 'allOf': [
Schema( {
'object', 'type': 'object',
required=['somestr'], 'required': ['somestr'],
properties={ 'properties': {
'somestr': Schema('string'), 'somestr': {
'type': 'string',
},
}, },
), },
Schema( {
'object', 'type': 'object',
required=['someint'], 'required': ['someint'],
properties={ 'properties': {
'someint': Schema('integer'), 'someint': {
'type': 'integer',
},
}, },
), }
], ],
additional_properties=False, 'additionalProperties': False,
) }
schema = SpecPath.from_spec(spec)
with pytest.raises(InvalidSchemaValue): with pytest.raises(InvalidSchemaValue):
unmarshaller_factory(schema)(value) unmarshaller_factory(schema)(value)
def test_schema_any_all_of_any(self, unmarshaller_factory): def test_schema_any_all_of_any(self, unmarshaller_factory):
schema = Schema(all_of=[ spec = {
Schema(), 'allOf': [
Schema('string', schema_format='date'), {},
]) {
'type': 'string',
'format': 'date',
},
],
}
schema = SpecPath.from_spec(spec)
value = '2018-01-02' value = '2018-01-02'
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
@ -534,7 +669,8 @@ class TestSchemaUnmarshallerCall(object):
assert result == datetime.date(2018, 1, 2) assert result == datetime.date(2018, 1, 2)
def test_schema_any(self, unmarshaller_factory): def test_schema_any(self, unmarshaller_factory):
schema = Schema() spec = {}
schema = SpecPath.from_spec(spec)
assert unmarshaller_factory(schema)('string') == 'string' assert unmarshaller_factory(schema)('string') == 'string'
@pytest.mark.parametrize('value', [ @pytest.mark.parametrize('value', [
@ -542,21 +678,30 @@ class TestSchemaUnmarshallerCall(object):
{'foo': 'bar', 'bar': 'foo'}, {'foo': 'bar', 'bar': 'foo'},
{'additional': {'bar': 1}}, {'additional': {'bar': 1}},
]) ])
@pytest.mark.parametrize('additional_properties', [True, Schema()]) @pytest.mark.parametrize('additional_properties', [True, {}])
def test_schema_free_form_object( def test_schema_free_form_object(
self, value, additional_properties, unmarshaller_factory): self, value, additional_properties, unmarshaller_factory):
schema = Schema('object', additional_properties=additional_properties) spec = {
'type': 'object',
'additionalProperties': additional_properties,
}
schema = SpecPath.from_spec(spec)
result = unmarshaller_factory(schema)(value) result = unmarshaller_factory(schema)(value)
assert result == value assert result == value
def test_read_only_properties(self, unmarshaller_factory): def test_read_only_properties(self, unmarshaller_factory):
id_property = Schema('integer', read_only=True) spec = {
'type': 'object',
def properties(): 'required': ['id'],
yield ('id', id_property) 'properties': {
'id': {
obj_schema = Schema('object', properties=properties(), required=['id']) 'type': 'integer',
'readOnly': True,
}
},
}
obj_schema = SpecPath.from_spec(spec)
# readOnly properties may be admitted in a Response context # readOnly properties may be admitted in a Response context
result = unmarshaller_factory( result = unmarshaller_factory(
@ -565,19 +710,36 @@ class TestSchemaUnmarshallerCall(object):
'id': 10, 'id': 10,
} }
# readOnly properties are not admitted on a Request context def test_read_only_properties_invalid(self, unmarshaller_factory):
result = unmarshaller_factory( spec = {
obj_schema, context=UnmarshalContext.REQUEST)({"id": 10}) 'type': 'object',
'required': ['id'],
'properties': {
'id': {
'type': 'integer',
'readOnly': True,
}
},
}
obj_schema = SpecPath.from_spec(spec)
assert result == {} # readOnly properties are not admitted on a Request context
with pytest.raises(InvalidSchemaValue):
unmarshaller_factory(
obj_schema, context=UnmarshalContext.REQUEST)({"id": 10})
def test_write_only_properties(self, unmarshaller_factory): def test_write_only_properties(self, unmarshaller_factory):
id_property = Schema('integer', write_only=True) spec = {
'type': 'object',
def properties(): 'required': ['id'],
yield ('id', id_property) 'properties': {
'id': {
obj_schema = Schema('object', properties=properties(), required=['id']) 'type': 'integer',
'writeOnly': True,
}
},
}
obj_schema = SpecPath.from_spec(spec)
# readOnly properties may be admitted in a Response context # readOnly properties may be admitted in a Response context
result = unmarshaller_factory( result = unmarshaller_factory(
@ -586,8 +748,20 @@ class TestSchemaUnmarshallerCall(object):
'id': 10, 'id': 10,
} }
# readOnly properties are not admitted on a Request context def test_write_only_properties_invalid(self, unmarshaller_factory):
result = unmarshaller_factory( spec = {
obj_schema, context=UnmarshalContext.RESPONSE)({"id": 10}) 'type': 'object',
'required': ['id'],
'properties': {
'id': {
'type': 'integer',
'writeOnly': True,
}
},
}
obj_schema = SpecPath.from_spec(spec)
assert result == {} # readOnly properties are not admitted on a Request context
with pytest.raises(InvalidSchemaValue):
unmarshaller_factory(
obj_schema, context=UnmarshalContext.RESPONSE)({"id": 10})

File diff suppressed because it is too large Load diff