diff --git a/tests/test_handler.py b/tests/test_handler.py new file mode 100644 index 0000000..f58a89e --- /dev/null +++ b/tests/test_handler.py @@ -0,0 +1,176 @@ +import json +import unittest.mock + +from openapi_core.exceptions import OpenAPIError # type: ignore +import tornado.httpclient +import tornado.web +import tornado.testing + +from tornado_openapi3.handler import OpenAPIRequestHandler + + +class ResourceHandler(OpenAPIRequestHandler): + spec = { + "openapi": "3.0.0", + "info": { + "title": "Test API", + "version": "1.0.0", + }, + "components": { + "schemas": { + "resource": { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + }, + }, + "securitySchemes": { + "basicAuth": { + "type": "http", + "scheme": "bearer", + } + }, + }, + "security": [{"basicAuth": []}], + "paths": { + "/resource": { + "post": { + "requestBody": { + "required": True, + "content": { + "application/vnd.example.resource+json": { + "schema": {"$ref": "#/components/schemas/resource"}, + } + }, + }, + "responses": { + "200": { + "description": "Success", + "content": { + "application/vnd.example.resource+json": { + "schema": {"$ref": "#/components/schemas/resource"}, + } + }, + }, + "401": { + "description": "Missing or invalid credentials", + }, + }, + } + } + }, + } + custom_media_type_deserializers = { + "application/vnd.example.resource+json": json.loads, + } + + async def post(self) -> None: + self.set_header("Content-Type", "application/vnd.example.resource+json") + self.finish( + json.dumps( + { + "name": self.validated.body["name"], + } + ) + ) + + +class RequestHandlerTests(tornado.testing.AsyncHTTPTestCase): + def get_app(self) -> tornado.web.Application: + return tornado.web.Application( + [ + (r"/resource", ResourceHandler), + (r"/undocumented", ResourceHandler), + ] + ) + + def test_invalid_operation(self) -> None: + response = self.fetch("/resource") + self.assertEqual(405, response.code) + + def test_bad_data(self) -> None: + response = self.fetch( + "/resource", + method="POST", + headers={ + "Authorization": "Bearer secret", + "Content-Type": "application/vnd.example.resource+json", + }, + body="asdf", + ) + self.assertEqual(400, response.code) + + def test_missing_field(self) -> None: + response = self.fetch( + "/resource", + method="POST", + headers={ + "Authorization": "Bearer secret", + "Content-Type": "application/vnd.example.resource+json", + }, + body=json.dumps({}), + ) + self.assertEqual(400, response.code) + + def test_missing_security(self) -> None: + response = self.fetch( + "/resource", + method="POST", + headers={ + "Content-Type": "application/vnd.example.resource+json", + }, + body=json.dumps({"name": "Name"}), + ) + self.assertEqual(401, response.code) + + def test_invalid_content_type(self) -> None: + response = self.fetch( + "/resource", + method="POST", + headers={ + "Authorization": "Bearer secret", + "Content-Type": "application/json", + }, + body=json.dumps({"name": "Name"}), + ) + self.assertEqual(415, response.code) + + def test_undocumented_endpoint(self) -> None: + response = self.fetch( + "/undocumented", + method="POST", + headers={ + "Authorization": "Bearer secret", + "Content-Type": "application/vnd.example.resource+json", + }, + body=json.dumps({"name": "Name"}), + ) + self.assertEqual(404, response.code) + + def test_unexpected_openapi_error(self) -> None: + with unittest.mock.patch( + "openapi_core.validation.datatypes.BaseValidationResult.raise_for_errors", + side_effect=OpenAPIError, + ): + response = self.fetch( + "/resource", + method="POST", + headers={ + "Authorization": "Bearer secret", + "Content-Type": "application/vnd.example.resource+json", + }, + body=json.dumps({"name": "Name"}), + ) + self.assertEqual(500, response.code) + + def test_success(self) -> None: + response = self.fetch( + "/resource", + method="POST", + headers={ + "Authorization": "Bearer secret", + "Content-Type": "application/vnd.example.resource+json", + }, + body=json.dumps({"name": "Name"}), + ) + self.assertEqual(200, response.code) diff --git a/tornado_openapi3/__init__.py b/tornado_openapi3/__init__.py index 268fa70..8157182 100644 --- a/tornado_openapi3/__init__.py +++ b/tornado_openapi3/__init__.py @@ -1,7 +1,9 @@ +from tornado_openapi3.handler import OpenAPIRequestHandler from tornado_openapi3.requests import RequestValidator, TornadoRequestFactory from tornado_openapi3.responses import ResponseValidator, TornadoResponseFactory __all__ = [ + "OpenAPIRequestHandler", "RequestValidator", "ResponseValidator", "TornadoRequestFactory", diff --git a/tornado_openapi3/handler.py b/tornado_openapi3/handler.py new file mode 100644 index 0000000..cfb9581 --- /dev/null +++ b/tornado_openapi3/handler.py @@ -0,0 +1,56 @@ +import asyncio +import logging + +from openapi_core import create_spec # type: ignore +from openapi_core.exceptions import OpenAPIError # type: ignore +from openapi_core.deserializing.exceptions import DeserializeError # type: ignore +from openapi_core.schema.media_types.exceptions import ( # type: ignore + InvalidContentType, +) +from openapi_core.templating.paths.exceptions import ( # type: ignore + OperationNotFound, + PathNotFound, +) +from openapi_core.unmarshalling.schemas.exceptions import ValidateError # type: ignore +from openapi_core.validation.exceptions import InvalidSecurity # type: ignore +import tornado.web + +from tornado_openapi3.requests import RequestValidator + +logger = logging.getLogger(__name__) + + +class OpenAPIRequestHandler(tornado.web.RequestHandler): + spec: dict = {} + custom_media_type_deserializers: dict = {} + + async def prepare(self) -> None: + maybe_coro = super().prepare() + if maybe_coro and asyncio.iscoroutine(maybe_coro): # pragma: no cover + await maybe_coro + + validator = RequestValidator( + create_spec(self.spec), + custom_media_type_deserializers=self.custom_media_type_deserializers, + ) + result = validator.validate(self.request) + try: + result.raise_for_errors() + except PathNotFound as e: + self.on_openapi_error(404, e) + except OperationNotFound as e: + self.on_openapi_error(405, e) + except (DeserializeError, ValidateError) as e: + self.on_openapi_error(400, e) + except InvalidSecurity as e: + self.on_openapi_error(401, e) + except InvalidContentType as e: + self.on_openapi_error(415, e) + except OpenAPIError as e: + logger.exception("Unexpected validation failure") + self.on_openapi_error(500, e) + self.validated = result + + def on_openapi_error(self, status_code: int, error: OpenAPIError) -> None: + self.set_status(status_code) + self.finish()