From ab2e10f99815dade2d4235750f016f3b6a4107d5 Mon Sep 17 00:00:00 2001 From: Artur Maciag Date: Mon, 3 Feb 2020 16:30:31 +0000 Subject: [PATCH] Custom media type deserializers --- README.rst | 25 +++++++++++++++++++ .../deserializing/media_types/factories.py | 16 +++++++++--- .../media_types/util.py | 0 openapi_core/validation/request/validators.py | 9 +++++-- .../validation/response/validators.py | 9 +++++-- tests/unit/deserializing/test_deserialize.py | 21 ++++++++++++++-- 6 files changed, 71 insertions(+), 9 deletions(-) rename openapi_core/{schema => deserializing}/media_types/util.py (100%) diff --git a/README.rst b/README.rst index 97399a8..75b2c88 100644 --- a/README.rst +++ b/README.rst @@ -111,6 +111,31 @@ and unmarshal response data from validation result Response object should be instance of OpenAPIResponse class (See `Integrations`_). +Customizations +############## + +Deserializers +************* + +Pass custom defined media type deserializers dictionary with supported mimetypes as a key to `RequestValidator` or `ResponseValidator` constructor: + +.. code-block:: python + + def protobuf_deserializer(message): + feature = route_guide_pb2.Feature() + feature.ParseFromString(message) + return feature + + custom_media_type_deserializers = { + 'application/protobuf': protobuf_deserializer, + } + + validator = ResponseValidator( + spec, custom_media_type_deserializers=custom_media_type_deserializers) + + result = validator.validate(request, response) + + Integrations ############ diff --git a/openapi_core/deserializing/media_types/factories.py b/openapi_core/deserializing/media_types/factories.py index 4877f7b..a6701c1 100644 --- a/openapi_core/deserializing/media_types/factories.py +++ b/openapi_core/deserializing/media_types/factories.py @@ -1,4 +1,4 @@ -from openapi_core.schema.media_types.util import json_loads +from openapi_core.deserializing.media_types.util import json_loads from openapi_core.deserializing.media_types.deserializers import ( PrimitiveDeserializer, @@ -11,8 +11,18 @@ class MediaTypeDeserializersFactory(object): 'application/json': json_loads, } + def __init__(self, custom_deserializers=None): + if custom_deserializers is None: + custom_deserializers = {} + self.custom_deserializers = custom_deserializers + def create(self, media_type): - deserialize_callable = self.MEDIA_TYPE_DESERIALIZERS.get( - media_type.mimetype, lambda x: x) + deserialize_callable = self.get_deserializer_callable( + media_type.mimetype) return PrimitiveDeserializer( media_type.mimetype, deserialize_callable) + + def get_deserializer_callable(self, mimetype): + if mimetype in self.custom_deserializers: + return self.custom_deserializers[mimetype] + return self.MEDIA_TYPE_DESERIALIZERS.get(mimetype, lambda x: x) diff --git a/openapi_core/schema/media_types/util.py b/openapi_core/deserializing/media_types/util.py similarity index 100% rename from openapi_core/schema/media_types/util.py rename to openapi_core/deserializing/media_types/util.py diff --git a/openapi_core/validation/request/validators.py b/openapi_core/validation/request/validators.py index bc57be0..f77cf94 100644 --- a/openapi_core/validation/request/validators.py +++ b/openapi_core/validation/request/validators.py @@ -23,9 +23,13 @@ from openapi_core.validation.util import get_operation_pattern class RequestValidator(object): - def __init__(self, spec, custom_formatters=None): + def __init__( + self, spec, + custom_formatters=None, custom_media_type_deserializers=None, + ): self.spec = spec self.custom_formatters = custom_formatters + self.custom_media_type_deserializers = custom_media_type_deserializers def validate(self, request): try: @@ -187,7 +191,8 @@ class RequestValidator(object): from openapi_core.deserializing.media_types.factories import ( MediaTypeDeserializersFactory, ) - deserializers_factory = MediaTypeDeserializersFactory() + deserializers_factory = MediaTypeDeserializersFactory( + self.custom_media_type_deserializers) deserializer = deserializers_factory.create(media_type) return deserializer(value) diff --git a/openapi_core/validation/response/validators.py b/openapi_core/validation/response/validators.py index 67db73f..241e8d9 100644 --- a/openapi_core/validation/response/validators.py +++ b/openapi_core/validation/response/validators.py @@ -16,9 +16,13 @@ from openapi_core.validation.util import get_operation_pattern class ResponseValidator(object): - def __init__(self, spec, custom_formatters=None): + def __init__( + self, spec, + custom_formatters=None, custom_media_type_deserializers=None, + ): self.spec = spec self.custom_formatters = custom_formatters + self.custom_media_type_deserializers = custom_media_type_deserializers def validate(self, request, response): try: @@ -112,7 +116,8 @@ class ResponseValidator(object): from openapi_core.deserializing.media_types.factories import ( MediaTypeDeserializersFactory, ) - deserializers_factory = MediaTypeDeserializersFactory() + deserializers_factory = MediaTypeDeserializersFactory( + self.custom_media_type_deserializers) deserializer = deserializers_factory.create(media_type) return deserializer(value) diff --git a/tests/unit/deserializing/test_deserialize.py b/tests/unit/deserializing/test_deserialize.py index aa800c3..f5b2921 100644 --- a/tests/unit/deserializing/test_deserialize.py +++ b/tests/unit/deserializing/test_deserialize.py @@ -51,8 +51,9 @@ class TestMediaTypeDeserialise(object): @pytest.fixture def deserializer_factory(self): - def create_deserializer(media_type): - return MediaTypeDeserializersFactory().create(media_type) + def create_deserializer(media_type, custom_deserializers=None): + return MediaTypeDeserializersFactory( + custom_deserializers=custom_deserializers).create(media_type) return create_deserializer def test_empty(self, deserializer_factory): @@ -69,3 +70,19 @@ class TestMediaTypeDeserialise(object): result = deserializer_factory(media_type)(value) assert result == {} + + def test_no_schema_custom_deserialiser(self, deserializer_factory): + custom_mimetype = 'application/custom' + media_type = MediaType(custom_mimetype) + value = "{}" + + def custom_deserializer(value): + return 'custom' + custom_deserializers = { + custom_mimetype: custom_deserializer, + } + + result = deserializer_factory( + media_type, custom_deserializers=custom_deserializers)(value) + + assert result == 'custom'