Split cast and deserialise processes

This commit is contained in:
Artur Maciag 2020-02-03 10:17:27 +00:00
parent 2c1a6c189c
commit b4c10e847a
6 changed files with 52 additions and 23 deletions

View file

@ -31,16 +31,17 @@ class MediaType(object):
deserializer = self.get_dererializer()
return deserializer(value)
def cast(self, value):
def deserialise(self, value):
try:
deserialized = self.deserialize(value)
return self.deserialize(value)
except ValueError as exc:
raise InvalidMediaTypeValue(exc)
def cast(self, value):
if not self.schema:
return deserialized
return value
try:
return self.schema.cast(deserialized)
return self.schema.cast(value)
except CastError as exc:
raise InvalidMediaTypeValue(exc)

View file

@ -88,7 +88,7 @@ class Parameter(object):
return location[self.name]
def cast(self, value):
def deserialise(self, value):
if self.deprecated:
warnings.warn(
"{0} parameter is deprecated".format(self.name),
@ -100,14 +100,15 @@ class Parameter(object):
raise EmptyParameterValue(self.name)
try:
deserialized = self.deserialize(value)
return self.deserialize(value)
except (ValueError, AttributeError) as exc:
raise InvalidParameterValue(self.name, exc)
def cast(self, value):
if not self.schema:
return deserialized
return value
try:
return self.schema.cast(deserialized)
return self.schema.cast(value)
except CastError as exc:
raise InvalidParameterValue(self.name, exc)

View file

@ -109,10 +109,15 @@ class RequestValidator(object):
casted = param.schema.default
else:
try:
casted = param.cast(raw_value)
deserialised = self._deserialise(param, raw_value)
except OpenAPIParameterError as exc:
errors.append(exc)
continue
else:
try:
casted = self._cast(param, deserialised)
except OpenAPIParameterError as exc:
errors.append(exc)
continue
try:
unmarshalled = self._unmarshal(param, casted)
@ -142,17 +147,28 @@ class RequestValidator(object):
errors.append(exc)
else:
try:
casted = media_type.cast(raw_body)
deserialised = self._deserialise(media_type, raw_body)
except InvalidMediaTypeValue as exc:
errors.append(exc)
else:
try:
body = self._unmarshal(media_type, casted)
except (ValidateError, UnmarshalError) as exc:
casted = self._cast(media_type, deserialised)
except InvalidMediaTypeValue as exc:
errors.append(exc)
else:
try:
body = self._unmarshal(media_type, casted)
except (ValidateError, UnmarshalError) as exc:
errors.append(exc)
return body, errors
def _deserialise(self, param_or_media_type, value):
return param_or_media_type.deserialise(value)
def _cast(self, param_or_media_type, value):
return param_or_media_type.cast(value)
def _unmarshal(self, param_or_media_type, value):
if not param_or_media_type.schema:
return value

View file

@ -81,14 +81,19 @@ class ResponseValidator(object):
errors.append(exc)
else:
try:
casted = media_type.cast(raw_data)
deserialised = self._deserialise(media_type, raw_data)
except InvalidMediaTypeValue as exc:
errors.append(exc)
else:
try:
data = self._unmarshal(media_type, casted)
except (ValidateError, UnmarshalError) as exc:
casted = self._cast(media_type, deserialised)
except InvalidMediaTypeValue as exc:
errors.append(exc)
else:
try:
data = self._unmarshal(media_type, casted)
except (ValidateError, UnmarshalError) as exc:
errors.append(exc)
return data, errors
@ -100,6 +105,12 @@ class ResponseValidator(object):
return headers, errors
def _deserialise(self, param_or_media_type, value):
return param_or_media_type.deserialise(value)
def _cast(self, param_or_media_type, value):
return param_or_media_type.cast(value)
def _unmarshal(self, param_or_media_type, value):
if not param_or_media_type.schema:
return value

View file

@ -4,19 +4,19 @@ from openapi_core.schema.media_types.exceptions import InvalidMediaTypeValue
from openapi_core.schema.media_types.models import MediaType
class TestMediaTypeCast(object):
class TestMediaTypeDeserialise(object):
def test_empty(self):
media_type = MediaType('application/json')
value = ''
with pytest.raises(InvalidMediaTypeValue):
media_type.cast(value)
media_type.deserialise(value)
def test_no_schema_deserialised(self):
media_type = MediaType('application/json')
value = "{}"
result = media_type.cast(value)
result = media_type.deserialise(value)
assert result == {}

View file

@ -38,14 +38,14 @@ class TestParameterInit(object):
assert param.explode is True
class TestParameterCast(object):
class TestParameterDeserialise(object):
def test_deprecated(self):
param = Parameter('param', 'query', deprecated=True)
value = 'test'
with pytest.warns(DeprecationWarning):
result = param.cast(value)
result = param.deserialise(value)
assert result == value
@ -54,12 +54,12 @@ class TestParameterCast(object):
value = ''
with pytest.raises(EmptyParameterValue):
param.cast(value)
param.deserialise(value)
def test_query_valid(self):
param = Parameter('param', 'query')
value = 'test'
result = param.cast(value)
result = param.deserialise(value)
assert result == value