Add an overrideable custom formatters property

This commit is contained in:
Correl Roush 2021-03-18 11:03:36 -04:00
parent 98cbaec514
commit 5fe5b75285
3 changed files with 78 additions and 2 deletions

View file

@ -1,4 +1,6 @@
import datetime
import json import json
import re
import unittest.mock import unittest.mock
from openapi_core.exceptions import OpenAPIError # type: ignore from openapi_core.exceptions import OpenAPIError # type: ignore
@ -9,6 +11,14 @@ import tornado.testing # type: ignore
from tornado_openapi3.handler import OpenAPIRequestHandler from tornado_openapi3.handler import OpenAPIRequestHandler
class USDateFormatter:
def validate(self, value: str) -> bool:
return bool(re.match(r"^\d{1,2}/\d{1,2}/\d{4}$", value))
def unmarshal(self, value: str) -> datetime.date:
return datetime.datetime.strptime(value, "%m/%d/%Y").date()
class ResourceHandler(OpenAPIRequestHandler): class ResourceHandler(OpenAPIRequestHandler):
spec_dict = { spec_dict = {
"openapi": "3.0.0", "openapi": "3.0.0",
@ -20,7 +30,10 @@ class ResourceHandler(OpenAPIRequestHandler):
"schemas": { "schemas": {
"resource": { "resource": {
"type": "object", "type": "object",
"properties": {"name": {"type": "string"}}, "properties": {
"name": {"type": "string"},
"date": {"type": "string", "format": "usdate"},
},
"required": ["name"], "required": ["name"],
}, },
}, },
@ -60,6 +73,11 @@ class ResourceHandler(OpenAPIRequestHandler):
} }
}, },
} }
custom_formatters = {
"usdate": USDateFormatter(),
}
custom_media_type_deserializers = { custom_media_type_deserializers = {
"application/vnd.example.resource+json": json.loads, "application/vnd.example.resource+json": json.loads,
} }
@ -98,6 +116,28 @@ class DefaultSchemaTest(tornado.testing.AsyncHTTPTestCase):
self.assertEqual(200, response.code) self.assertEqual(200, response.code)
class DefaultFormatters(tornado.testing.AsyncHTTPTestCase):
def get_app(self) -> tornado.web.Application:
test = self
class RequestHandler(OpenAPIRequestHandler):
async def prepare(self) -> None:
test.assertEqual(dict(), self.custom_formatters)
async def get(self) -> None:
...
return tornado.web.Application(
[
(r"/", RequestHandler),
]
)
def test_schema_must_be_implemented(self) -> None:
response = self.fetch("/")
self.assertEqual(200, response.code)
class DefaultDeserializers(tornado.testing.AsyncHTTPTestCase): class DefaultDeserializers(tornado.testing.AsyncHTTPTestCase):
def get_app(self) -> tornado.web.Application: def get_app(self) -> tornado.web.Application:
test = self test = self
@ -192,6 +232,18 @@ class RequestHandlerTests(tornado.testing.AsyncHTTPTestCase):
) )
self.assertEqual(404, response.code) self.assertEqual(404, response.code)
def test_format_error(self) -> None:
response = self.fetch(
"/resource",
method="POST",
headers={
"Authorization": "Bearer secret",
"Content-Type": "application/vnd.example.resource+json",
},
body=json.dumps({"name": "Name", "date": "2020.01.01"}),
)
self.assertEqual(400, response.code)
def test_unexpected_openapi_error(self) -> None: def test_unexpected_openapi_error(self) -> None:
with unittest.mock.patch( with unittest.mock.patch(
"openapi_core.validation.datatypes.BaseValidationResult.raise_for_errors", "openapi_core.validation.datatypes.BaseValidationResult.raise_for_errors",
@ -216,6 +268,6 @@ class RequestHandlerTests(tornado.testing.AsyncHTTPTestCase):
"Authorization": "Bearer secret", "Authorization": "Bearer secret",
"Content-Type": "application/vnd.example.resource+json", "Content-Type": "application/vnd.example.resource+json",
}, },
body=json.dumps({"name": "Name"}), body=json.dumps({"name": "Name", "date": "01/01/2020"}),
) )
self.assertEqual(200, response.code) self.assertEqual(200, response.code)

View file

@ -61,6 +61,17 @@ class OpenAPIRequestHandler(tornado.web.RequestHandler):
""" """
return create_spec(self.spec_dict, validate_spec=False) return create_spec(self.spec_dict, validate_spec=False)
@property
def custom_formatters(self) -> dict:
"""A dictionary mapping value formats to formatter objects.
A formatter object must provide:
- validate(self, value) -> bool
- unmarshal(self, value) -> Any
"""
return dict()
@property @property
def custom_media_type_deserializers(self) -> dict: def custom_media_type_deserializers(self) -> dict:
"""A dictionary mapping media types to deserializing functions. """A dictionary mapping media types to deserializing functions.
@ -115,6 +126,7 @@ class OpenAPIRequestHandler(tornado.web.RequestHandler):
validator = RequestValidator( validator = RequestValidator(
self.spec, self.spec,
custom_formatters=self.custom_formatters,
custom_media_type_deserializers=self.custom_media_type_deserializers, custom_media_type_deserializers=self.custom_media_type_deserializers,
) )
result = validator.validate(self.request) result = validator.validate(self.request)

View file

@ -40,6 +40,17 @@ class AsyncOpenAPITestCase(tornado.testing.AsyncHTTPTestCase):
""" """
return create_spec(self.spec_dict) return create_spec(self.spec_dict)
@property
def custom_formatters(self) -> dict:
"""A dictionary mapping value formats to formatter objects.
A formatter object must provide:
- validate(self, value) -> bool
- unmarshal(self, value) -> Any
"""
return dict()
@property @property
def custom_media_type_deserializers(self) -> dict: def custom_media_type_deserializers(self) -> dict:
"""A dictionary mapping media types to deserializing functions. """A dictionary mapping media types to deserializing functions.
@ -61,6 +72,7 @@ class AsyncOpenAPITestCase(tornado.testing.AsyncHTTPTestCase):
super().setUp() super().setUp()
self.validator = ResponseValidator( self.validator = ResponseValidator(
self.spec, self.spec,
custom_formatters=self.custom_formatters,
custom_media_type_deserializers=self.custom_media_type_deserializers, custom_media_type_deserializers=self.custom_media_type_deserializers,
) )