Update openapi-core to 0.19.4+

This commit is contained in:
Correl Roush 2024-10-18 17:04:45 -04:00
parent 10de1bd477
commit 5d442073ad
12 changed files with 432 additions and 526 deletions

View file

@ -14,7 +14,7 @@ packages = [
[tool.poetry.dependencies]
python = "^3.9"
tornado = "^5 || ^6"
openapi-core = "^0.14.2"
openapi-core = "^0.19.4"
ietfparse = "^1.8.0"
typing-extensions = "^4.0.1"

21
tests/common.py Normal file
View file

@ -0,0 +1,21 @@
import hypothesis.strategies as s
from werkzeug.datastructures import Headers
field_names = s.text(
s.characters(
min_codepoint=33,
max_codepoint=126,
blacklist_categories=["Lu"],
blacklist_characters=":\r\n",
),
min_size=1,
)
field_values = s.text(
s.characters(min_codepoint=0x20, max_codepoint=0x7E, blacklist_characters="; \r\n"),
min_size=1,
)
headers: s.SearchStrategy[Headers] = s.builds(
Headers, s.lists(s.tuples(field_names, field_values))
)

View file

@ -1,14 +1,16 @@
import datetime
import json
import re
import typing
import unittest.mock
from openapi_core.exceptions import OpenAPIError # type: ignore
from openapi_core.exceptions import OpenAPIError
import tornado.httpclient
import tornado.web
import tornado.testing
from tornado_openapi3.handler import OpenAPIRequestHandler
from tornado_openapi3.types import Deserializer, Formatter
class USDateFormatter:
@ -74,23 +76,25 @@ class ResourceHandler(OpenAPIRequestHandler):
},
}
custom_formatters = {
@property
def custom_formatters(self) -> typing.Dict[str, Formatter]:
return {
"usdate": USDateFormatter(),
}
custom_media_type_deserializers = {
@property
def custom_media_type_deserializers(self) -> typing.Dict[str, Deserializer]:
return {
"application/vnd.example.resource+json": json.loads,
}
async def post(self) -> None:
self.set_header("Content-Type", "application/vnd.example.resource+json")
self.finish(
json.dumps(
{
"name": self.validated.body["name"],
}
)
)
body = b""
if isinstance(self.validated.body, dict) and "name" in self.validated.body:
body = json.dumps({"name": self.validated.body["name"]}).encode()
self.finish(body)
class DefaultSchemaTest(tornado.testing.AsyncHTTPTestCase):
@ -102,8 +106,7 @@ class DefaultSchemaTest(tornado.testing.AsyncHTTPTestCase):
with test.assertRaises(NotImplementedError):
self.spec
async def get(self) -> None:
...
async def get(self) -> None: ...
return tornado.web.Application(
[
@ -124,8 +127,7 @@ class DefaultFormatters(tornado.testing.AsyncHTTPTestCase):
async def prepare(self) -> None:
test.assertEqual(dict(), self.custom_formatters)
async def get(self) -> None:
...
async def get(self) -> None: ...
return tornado.web.Application(
[
@ -146,8 +148,7 @@ class DefaultDeserializers(tornado.testing.AsyncHTTPTestCase):
async def prepare(self) -> None:
test.assertEqual(dict(), self.custom_media_type_deserializers)
async def get(self) -> None:
...
async def get(self) -> None: ...
return tornado.web.Application(
[
@ -246,7 +247,7 @@ class RequestHandlerTests(tornado.testing.AsyncHTTPTestCase):
def test_unexpected_openapi_error(self) -> None:
with unittest.mock.patch(
"openapi_core.validation.datatypes.BaseValidationResult.raise_for_errors",
"openapi_core.OpenAPI.unmarshal_request",
side_effect=OpenAPIError,
):
response = self.fetch(

View file

@ -1,257 +1,196 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple
import dataclasses
import http.cookies
import string
import typing
import unittest
from urllib.parse import urlencode, urlparse
import urllib.parse
import attr
from hypothesis import given
import hypothesis.strategies as s # type: ignore
from openapi_core import create_spec # type: ignore
from openapi_core.exceptions import ( # type: ignore
MissingRequiredParameter,
OpenAPIError,
)
from openapi_core.validation.request.datatypes import ( # type: ignore
RequestParameters,
OpenAPIRequest,
)
from tornado.httpclient import HTTPRequest
from tornado.httputil import HTTPHeaders, HTTPServerRequest
from tornado.testing import AsyncHTTPTestCase
from tornado.web import Application, RequestHandler
from hypothesis import given, provisional
import hypothesis.strategies as s
import openapi_core.datatypes
import openapi_core.protocols
from openapi_core.validation.request.datatypes import RequestParameters
import tornado.httpclient
import tornado.httputil
from werkzeug.datastructures import ImmutableMultiDict
from tornado_openapi3 import RequestValidator, TornadoRequestFactory
import tornado_openapi3.requests
from tests import common
@dataclass
class Parameters:
headers: Dict[str, str]
query_parameters: Dict[str, str]
methods = s.sampled_from(
["get", "head", "post", "put", "delete", "connect", "options", "trace", "patch"]
)
def as_openapi(self) -> List[dict]:
headers = [
{
"name": name.lower(),
"in": "header",
"required": True,
"schema": {"type": "string", "enum": [value]},
}
for name, value in self.headers.items()
]
qargs = [
{
"name": name.lower(),
"in": "query",
"required": True,
"schema": {"type": "string", "enum": [value]},
}
for name, value in self.query_parameters.items()
]
return headers + qargs
field_name = s.text(
s.characters(
min_codepoint=33,
max_codepoint=126,
blacklist_categories=("Lu",),
blacklist_characters=":",
queries: s.SearchStrategy[ImmutableMultiDict[str, str]] = s.builds(
ImmutableMultiDict,
s.lists(
s.tuples(common.field_names, common.field_values),
),
min_size=1,
)
field_value = s.text(
s.characters(min_codepoint=0x20, max_codepoint=0x7E, blacklist_characters=" \r\n"),
cookies: s.SearchStrategy[ImmutableMultiDict[str, str]] = s.builds(
ImmutableMultiDict,
s.dictionaries(
s.text(
alphabet=string.ascii_letters + string.digits + "!#$%&'*+-.^_`|~:",
min_size=1,
),
common.field_values,
),
)
request_parameters = s.builds(
RequestParameters,
query=queries,
header=common.headers,
cookie=cookies,
)
def headers(min_size: int = 0) -> s.SearchStrategy[Dict[str, str]]:
return s.dictionaries(field_name, field_value, min_size=min_size)
def query_parameters(min_size: int = 0) -> s.SearchStrategy[Dict[str, str]]:
return s.dictionaries(field_name, field_value, min_size=min_size)
@dataclasses.dataclass
class TestOpenAPIRequest:
parameters: openapi_core.datatypes.RequestParameters
method: str
body: typing.Optional[bytes]
content_type: str
host_url: str
path: str
@s.composite
def parameters(
draw: Callable[[Any], Any], min_headers: int = 0, min_query_parameters: int = 0
) -> Parameters:
return Parameters(
headers=draw(headers(min_size=min_headers)),
query_parameters=draw(query_parameters(min_size=min_query_parameters)),
def openapi_requests(
draw: typing.Callable[[typing.Any], typing.Any]
) -> openapi_core.protocols.Request:
url = draw(provisional.urls())
parts = urllib.parse.urlparse(url)
content_type = draw(common.field_values)
parameters = draw(request_parameters)
parameters.header["Content-Type"] = content_type
if parameters.cookie:
cookie = http.cookies.SimpleCookie()
for key, value in parameters.cookie.items():
cookie[key] = value
for header in cookie.output(header="").splitlines():
parameters.header.add_header("Cookie", header.strip())
return TestOpenAPIRequest(
parameters=parameters,
method=draw(methods),
body=draw(s.one_of(s.none(), s.binary())),
content_type=content_type,
host_url="{}://{}".format(parts.scheme, parts.netloc),
path=parts.path,
)
class TestRequestFactory(unittest.TestCase):
@given(
s.one_of(
s.tuples(s.just(""), s.just(dict())),
s.tuples(s.just("http://example.com/foo"), query_parameters()),
)
)
def test_http_request(self, opts: Tuple[str, Dict[str, str]]) -> None:
url, parameters = opts
request_url = f"{url}?{urlencode(parameters)}" if url else ""
tornado_request = HTTPRequest(method="GET", url=request_url)
expected = OpenAPIRequest(
full_url_pattern=url,
method="get",
parameters=RequestParameters(query=ImmutableMultiDict(parameters)),
body=None,
mimetype="application/x-www-form-urlencoded",
)
openapi_request = TornadoRequestFactory.create(tornado_request)
self.assertEqual(attr.asdict(expected), attr.asdict(openapi_request))
@given(
s.one_of(
s.tuples(s.just(""), s.just(dict())),
s.tuples(s.just("http://example.com/foo"), query_parameters()),
)
)
def test_http_server_request(self, opts: Tuple[str, Dict[str, str]]) -> None:
url, parameters = opts
request_url = f"{url}?{urlencode(parameters)}" if url else ""
parsed = urlparse(request_url)
tornado_request = HTTPServerRequest(
method="GET",
uri=f"{parsed.path}?{parsed.query}",
)
tornado_request.protocol = parsed.scheme
tornado_request.host = parsed.netloc.split(":")[0]
expected = OpenAPIRequest(
full_url_pattern=url,
method="get",
parameters=RequestParameters(
query=ImmutableMultiDict(parameters), path={}, cookie={}
),
body=None,
mimetype="application/x-www-form-urlencoded",
)
openapi_request = TornadoRequestFactory.create(tornado_request)
self.assertEqual(attr.asdict(expected), attr.asdict(openapi_request))
class TestRequest(AsyncHTTPTestCase):
def setUp(self) -> None:
super(TestRequest, self).setUp()
self.request: Optional[HTTPServerRequest] = None
def get_app(self) -> Application:
testcase = self
class TestHandler(RequestHandler):
def get(self) -> None:
nonlocal testcase
testcase.request = self.request
return Application([(r"/.*", TestHandler)])
@given(parameters())
def test_simple_request(self, parameters: Parameters) -> None:
spec = create_spec(
{
"openapi": "3.0.0",
"info": {"title": "Test specification", "version": "0.1"},
"paths": {
"/": {
"get": {
"parameters": parameters.as_openapi(),
"responses": {"default": {"description": "Root response"}},
}
}
},
}
)
validator = RequestValidator(spec)
self.fetch(
"/?" + urlencode(parameters.query_parameters),
headers=HTTPHeaders(parameters.headers),
)
assert self.request is not None
result = validator.validate(self.request)
result.raise_for_errors()
@given(parameters(min_headers=1) | parameters(min_query_parameters=1))
def test_simple_request_fails_without_parameters(
self, parameters: Parameters
class RequestTests(unittest.TestCase):
def assertOpenAPIRequestsEqual(
self,
value: openapi_core.protocols.Request,
expected: openapi_core.protocols.Request,
) -> None:
spec = create_spec(
{
"openapi": "3.0.0",
"info": {"title": "Test specification", "version": "0.1"},
"paths": {
"/": {
"get": {
"parameters": parameters.as_openapi(),
"responses": {"default": {"description": "Root response"}},
}
}
},
}
self.assertEqual(
value.parameters.query,
expected.parameters.query,
"Query parameters are equal",
)
validator = RequestValidator(spec)
self.fetch("/")
assert self.request is not None
result = validator.validate(self.request)
with self.assertRaises(MissingRequiredParameter):
result.raise_for_errors()
self.assertEqual(
value.parameters.header, expected.parameters.header, "Headers are equal"
)
self.assertEqual(
value.parameters.cookie, expected.parameters.cookie, "Cookies are equal"
)
self.assertEqual(value.method, expected.method, "HTTP methods are equal")
self.assertEqual(value.body, expected.body, "Bodies are equal")
self.assertEqual(
value.content_type, expected.content_type, "Content types are equal"
)
self.assertEqual(value.host_url, expected.host_url, "Host URLs are equal")
self.assertEqual(value.path, expected.path, "Paths are equal")
def test_url_parameters(self) -> None:
spec = create_spec(
{
"openapi": "3.0.0",
"info": {"title": "Test specification", "version": "0.1"},
"paths": {
"/{id}": {
"get": {
"parameters": [
{
"name": "id",
"in": "path",
"required": True,
"schema": {"type": "integer"},
}
],
"responses": {"default": {"description": "Root response"}},
}
}
},
}
def url_from_openapi_request(self, request: TestOpenAPIRequest) -> str:
scheme, netloc = request.host_url.split("://")
params = ""
# Preserves multiple values if the parameters are a multidict. This
# whole dance is because ImmutableMultiDict's .items() does not return
# more than one pair per key. Curiously, the Headers structure from the
# same library does.
qsl: typing.List[typing.Tuple[str, str]] = []
query_parameters = ImmutableMultiDict(request.parameters.query)
for key in query_parameters.keys():
for value in query_parameters.getlist(key):
qsl.append((key, value))
query = urllib.parse.urlencode(qsl)
fragment = ""
return urllib.parse.urlunparse(
(
scheme,
netloc,
request.path,
params,
query,
fragment,
)
)
validator = RequestValidator(spec)
self.fetch("/1234")
assert self.request is not None
result = validator.validate(self.request)
result.raise_for_errors()
def test_bad_url_parameters(self) -> None:
spec = create_spec(
{
"openapi": "3.0.0",
"info": {"title": "Test specification", "version": "0.1"},
"paths": {
"/{id}": {
"get": {
"parameters": [
{
"name": "id",
"in": "path",
"required": True,
"schema": {"type": "integer"},
}
],
"responses": {"default": {"description": "Root response"}},
}
}
},
}
def tornado_headers_from_openapi_request(
self, request: TestOpenAPIRequest
) -> tornado.httputil.HTTPHeaders:
headers = tornado.httputil.HTTPHeaders()
for key, value in request.parameters.header.items():
headers.add(key, value)
headers["Content-Type"] = request.content_type
if request.parameters.cookie:
cookie = http.cookies.SimpleCookie()
for key, value in request.parameters.cookie.items():
cookie[key] = value
for header in cookie.output(header="").splitlines():
headers.add("Cookie", header.strip())
return headers
def openapi_to_tornado_request(
self, request: TestOpenAPIRequest
) -> tornado.httpclient.HTTPRequest:
url = self.url_from_openapi_request(request)
headers = self.tornado_headers_from_openapi_request(request)
return tornado.httpclient.HTTPRequest(
url,
method=request.method.upper(),
headers=headers,
body=request.body,
)
validator = RequestValidator(spec)
self.fetch("/abcd")
assert self.request is not None
result = validator.validate(self.request)
with self.assertRaises(OpenAPIError):
result.raise_for_errors()
def openapi_to_tornado_server_request(
self, request: TestOpenAPIRequest
) -> tornado.httputil.HTTPServerRequest:
url = self.url_from_openapi_request(request)
headers = self.tornado_headers_from_openapi_request(request)
uri = url.removeprefix(request.host_url)
server_request = tornado.httputil.HTTPServerRequest(
method=request.method.upper(), uri=uri, headers=headers, body=request.body
)
scheme, netloc = request.host_url.split("://")
server_request.protocol = scheme
server_request.host = netloc
return server_request
@given(openapi_requests())
def test_http_request_round_trip_conversion(
self, request: TestOpenAPIRequest
) -> None:
converted = tornado_openapi3.requests.TornadoOpenAPIRequest(
self.openapi_to_tornado_request(request)
)
self.assertOpenAPIRequestsEqual(converted, request)
@given(openapi_requests())
def test_http_server_request_round_trip_conversion(
self, request: TestOpenAPIRequest
) -> None:
# HTTP Server request bodies are not optional
request.body = request.body or b""
converted = tornado_openapi3.requests.TornadoOpenAPIRequest(
self.openapi_to_tornado_server_request(request)
)
self.assertOpenAPIRequestsEqual(converted, request)

View file

@ -1,110 +1,81 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional
import dataclasses
import io
import typing
import unittest
import attr
from hypothesis import given
import hypothesis.strategies as s
from openapi_core import create_spec # type: ignore
from openapi_core.validation.response.datatypes import OpenAPIResponse # type: ignore
from tornado.httpclient import HTTPRequest, HTTPResponse
from tornado.testing import AsyncHTTPTestCase
from tornado.web import Application, RequestHandler
import openapi_core.protocols
import tornado.httpclient
import tornado.httputil
from werkzeug.datastructures import Headers
from tornado_openapi3 import (
ResponseValidator,
TornadoResponseFactory,
)
import tornado_openapi3.responses
from tests import common
import tornado_openapi3
@dataclass
class Responses:
code: int
headers: Dict[str, str]
def as_openapi(self) -> Dict[str, Any]:
return {
str(self.code): {
"description": "Response",
"headers": {
name: {"schema": {"type": "string", "enum": [value]}}
for name, value in self.headers.items()
},
}
}
@dataclasses.dataclass
class TestOpenAPIResponse:
status_code: int
headers: Headers
content_type: str
data: typing.Optional[bytes]
@s.composite
def responses(draw: Callable[[Any], Any], min_headers: int = 0) -> Responses:
field_name = s.text(
s.characters(
min_codepoint=33,
max_codepoint=126,
blacklist_categories=("Lu",),
blacklist_characters=":",
),
min_size=1,
)
field_value = s.text(
s.characters(
min_codepoint=0x20, max_codepoint=0x7E, blacklist_characters=" \r\n"
),
min_size=1,
)
code = s.sampled_from([200, 304, 400, 500])
headers = s.dictionaries(field_name, field_value, min_size=min_headers)
return Responses(
code=draw(code),
headers=draw(headers),
def openapi_responses(
draw: typing.Callable[[typing.Any], typing.Any]
) -> openapi_core.protocols.Response:
status_code = draw(s.integers(min_value=100, max_value=599))
headers = draw(common.headers)
content_type = draw(common.field_values)
headers["Content-Type"] = content_type
data = draw(s.binary())
return TestOpenAPIResponse(
status_code=status_code,
headers=headers,
content_type=content_type,
data=data,
)
class TestResponseFactory(unittest.TestCase):
def test_response(self) -> None:
tornado_request = HTTPRequest(url="http://example.com")
tornado_response = HTTPResponse(request=tornado_request, code=200)
expected = OpenAPIResponse(
data=b"",
status_code=200,
mimetype="text/html",
class ResponseTests(unittest.TestCase):
def assertOpenAPIResponsesEqual(
self,
value: openapi_core.protocols.Response,
expected: openapi_core.protocols.Response,
) -> None:
self.assertEqual(
value.status_code, expected.status_code, "Status codes are equal"
)
openapi_response = TornadoResponseFactory.create(tornado_response)
self.assertEqual(attr.asdict(expected), attr.asdict(openapi_response))
class ResponsesHandler(RequestHandler):
responses: Optional[Responses] = None
def get(self) -> None:
if ResponsesHandler.responses:
self.set_status(ResponsesHandler.responses.code)
for name, value in ResponsesHandler.responses.headers.items():
self.add_header(name, value)
class TestResponse(AsyncHTTPTestCase):
def get_app(self) -> Application:
return Application([(r"/.*", ResponsesHandler)])
@given(responses())
def test_simple_request(self, responses: Responses) -> None:
spec = create_spec(
{
"openapi": "3.0.0",
"info": {"title": "Test specification", "version": "0.1"},
"paths": {
"/": {
"get": {
"responses": responses.as_openapi(),
}
}
},
}
self.assertEqual(value.headers, expected.headers, "Headers are equal")
self.assertEqual(
value.content_type, expected.content_type, "Content types are equal"
)
ResponsesHandler.responses = responses
validator = ResponseValidator(spec)
response = self.fetch("/")
result = validator.validate(response)
result.raise_for_errors()
self.assertEqual(value.data, expected.data, "Bodies are equal")
def openapi_to_tornado_response(
self, response: TestOpenAPIResponse
) -> tornado.httpclient.HTTPResponse:
headers = tornado.httputil.HTTPHeaders()
for key, value in response.headers.items():
headers.add(key, value)
return tornado.httpclient.HTTPResponse(
request=tornado.httpclient.HTTPRequest(""),
code=response.status_code,
headers=headers,
buffer=io.BytesIO(response.data or b""),
)
@given(openapi_responses())
def test_http_response_round_trip_conversion(
self, response: TestOpenAPIResponse
) -> None:
converted = tornado_openapi3.responses.TornadoOpenAPIResponse(
self.openapi_to_tornado_response(response)
)
self.assertOpenAPIResponsesEqual(converted, response)

View file

@ -1,11 +1,13 @@
import json
import typing
from openapi_core.templating.responses.exceptions import ( # type: ignore
from openapi_core.templating.responses.exceptions import (
ResponseNotFound,
)
import tornado.web
from tornado_openapi3.handler import OpenAPIRequestHandler
from tornado_openapi3.testing import AsyncOpenAPITestCase
from tornado_openapi3.types import Deserializer
def spec(responses: dict = dict()) -> dict:
@ -37,11 +39,9 @@ def spec(responses: dict = dict()) -> dict:
class TestTestCase(AsyncOpenAPITestCase):
def setUp(self) -> None:
...
def setUp(self) -> None: ...
def tearDown(self) -> None:
...
def tearDown(self) -> None: ...
def test_schema_must_be_implemented(self) -> None:
with self.assertRaises(NotImplementedError):
@ -52,8 +52,13 @@ class TestTestCase(AsyncOpenAPITestCase):
class BaseTestCase(AsyncOpenAPITestCase):
spec_dict = spec()
custom_media_type_deserializers = {
@property
def spec_dict(self) -> dict:
return spec()
@property
def custom_media_type_deserializers(self) -> typing.Dict[str, Deserializer]:
return {
"application/vnd.example.resource+json": json.loads,
}
@ -61,16 +66,22 @@ class BaseTestCase(AsyncOpenAPITestCase):
testcase = self
class ResourceHandler(OpenAPIRequestHandler):
spec = self.spec
custom_media_type_deserializers = self.custom_media_type_deserializers
@property
def spec_dict(self) -> dict:
return spec()
@property
def custom_media_type_deserializers(self) -> typing.Dict[str, Deserializer]:
return {
"application/vnd.example.resource+json": json.loads,
}
async def get(self) -> None:
await testcase.get(self)
return tornado.web.Application([(r"/resource", ResourceHandler)])
async def get(self, handler: tornado.web.RequestHandler) -> None:
...
async def get(self, handler: tornado.web.RequestHandler) -> None: ...
class SuccessTests(BaseTestCase):
@ -97,7 +108,9 @@ class SuccessTests(BaseTestCase):
class IncorrectResponseTests(BaseTestCase):
spec_dict = spec(responses={"200": {"description": "Success"}})
@property
def spec_dict(self) -> dict:
return spec(responses={"200": {"description": "Success"}})
async def get(self, handler: tornado.web.RequestHandler) -> None:
handler.set_status(400)
@ -112,7 +125,9 @@ class IncorrectResponseTests(BaseTestCase):
class RaiseErrorTests(BaseTestCase):
spec_dict = spec(
@property
def spec_dict(self) -> dict:
return spec(
responses={
"500": {
"description": "An error has occurred.",

View file

@ -1,11 +1,11 @@
from tornado_openapi3.handler import OpenAPIRequestHandler
from tornado_openapi3.requests import RequestValidator, TornadoRequestFactory
from tornado_openapi3.responses import ResponseValidator, TornadoResponseFactory
# from tornado_openapi3.handler import OpenAPIRequestHandler
# from tornado_openapi3.requests import RequestValidator, TornadoRequestFactory
# from tornado_openapi3.responses import ResponseValidator, TornadoResponseFactory
__all__ = [
"OpenAPIRequestHandler",
"RequestValidator",
"ResponseValidator",
"TornadoRequestFactory",
"TornadoResponseFactory",
]
# __all__ = [
# "OpenAPIRequestHandler",
# "RequestValidator",
# "ResponseValidator",
# "TornadoRequestFactory",
# "TornadoResponseFactory",
# ]

View file

@ -1,28 +1,25 @@
import asyncio
import logging
from typing import Mapping
import typing
from openapi_core import create_spec # type: ignore
from openapi_core.casting.schemas.exceptions import CastError # type: ignore
from openapi_core.exceptions import ( # type: ignore
MissingRequestBody,
MissingRequiredParameter,
OpenAPIError,
import openapi_core
import openapi_core.validation.request.exceptions
from openapi_core.exceptions import OpenAPIError
from openapi_core.validation.request.exceptions import (
RequestBodyValidationError,
SecurityValidationError,
)
from openapi_core.deserializing.exceptions import DeserializeError # type: ignore
from openapi_core.spec.paths import SpecPath # type: ignore
from openapi_core.templating.media_types.exceptions import ( # type: ignore
from openapi_core.templating.media_types.exceptions import (
MediaTypeNotFound,
)
from openapi_core.templating.paths.exceptions import ( # type: ignore
from openapi_core.templating.paths.exceptions import (
OperationNotFound,
PathNotFound,
)
from openapi_core.unmarshalling.schemas.exceptions import ValidateError # type: ignore
from openapi_core.validation.exceptions import InvalidSecurity # type: ignore
import tornado.web
from tornado_openapi3.requests import RequestValidator
import tornado_openapi3.requests
import tornado_openapi3.types
from tornado_openapi3.types import Deserializer, Formatter
logger = logging.getLogger(__name__)
@ -50,7 +47,7 @@ class OpenAPIRequestHandler(tornado.web.RequestHandler):
raise NotImplementedError()
@property
def spec(self) -> SpecPath:
def spec(self) -> openapi_core.OpenAPI:
"""The OpenAPI 3 specification.
Override this in your request handlers to customize how your OpenAPI 3
@ -59,10 +56,21 @@ class OpenAPIRequestHandler(tornado.web.RequestHandler):
:rtype: :class:`openapi_core.schema.specs.model.Spec`
"""
return create_spec(self.spec_dict, validate_spec=False)
config = openapi_core.Config(
extra_format_unmarshallers={
format: formatter.unmarshal
for format, formatter in self.custom_formatters.items()
},
extra_format_validators={
format: formatter.validate
for format, formatter in self.custom_formatters.items()
},
extra_media_type_deserializers=self.custom_media_type_deserializers,
)
return openapi_core.OpenAPI.from_dict(self.spec_dict, config=config)
@property
def custom_formatters(self) -> Mapping[str, Formatter]:
def custom_formatters(self) -> typing.Dict[str, Formatter]:
"""A dictionary mapping value formats to formatter objects.
If your schemas make use of format modifiers, you may specify them in
@ -76,7 +84,7 @@ class OpenAPIRequestHandler(tornado.web.RequestHandler):
return dict()
@property
def custom_media_type_deserializers(self) -> Mapping[str, Deserializer]:
def custom_media_type_deserializers(self) -> typing.Dict[str, Deserializer]:
"""A dictionary mapping media types to deserializing functions.
If your endpoints make use of content types beyond ``application/json``,
@ -128,31 +136,22 @@ class OpenAPIRequestHandler(tornado.web.RequestHandler):
if maybe_coro and asyncio.iscoroutine(maybe_coro): # pragma: no cover
await maybe_coro
validator = RequestValidator(
self.spec,
custom_formatters=self.custom_formatters,
custom_media_type_deserializers=self.custom_media_type_deserializers,
)
result = validator.validate(self.request)
request = tornado_openapi3.requests.TornadoOpenAPIRequest(self.request)
result = self.spec.unmarshal_request(request)
try:
result.raise_for_errors()
except PathNotFound as e:
self.on_openapi_error(404, e)
except OperationNotFound as e:
self.on_openapi_error(405, e)
except (
CastError,
DeserializeError,
MissingRequiredParameter,
MissingRequestBody,
ValidateError,
) as e:
self.on_openapi_error(400, e)
except InvalidSecurity as e:
self.on_openapi_error(401, e)
except MediaTypeNotFound as e:
except RequestBodyValidationError as e:
if isinstance(e.__cause__, MediaTypeNotFound):
self.on_openapi_error(415, e)
except OpenAPIError as e:
else:
self.on_openapi_error(400, e)
except SecurityValidationError as e:
self.on_openapi_error(401, e)
except OpenAPIError as e: # pragma: no cover
logger.exception("Unexpected validation failure")
self.on_openapi_error(500, e)
self.validated = result

View file

@ -1,75 +1,58 @@
import itertools
import typing
import urllib.parse
from urllib.parse import parse_qsl
from typing import Union
from openapi_core.validation.request.datatypes import ( # type: ignore
OpenAPIRequest,
RequestParameters,
RequestValidationResult,
)
from openapi_core.validation.request import validators # type: ignore
from openapi_core.validation.request.datatypes import RequestParameters
from tornado.httpclient import HTTPRequest
from tornado.httputil import HTTPServerRequest, parse_cookie
from werkzeug.datastructures import ImmutableMultiDict, Headers
from .util import parse_mimetype
class TornadoRequestFactory:
"""Factory for converting Tornado requests to OpenAPI request objects."""
@classmethod
def create(cls, request: Union[HTTPRequest, HTTPServerRequest]) -> OpenAPIRequest:
"""Creates an OpenAPI request from Tornado request objects.
class TornadoOpenAPIRequest:
def __init__(self, request: typing.Union[HTTPRequest, HTTPServerRequest]) -> None:
"""Create an OpenAPI request from Tornado request objects.
Supports both :class:`tornado.httpclient.HTTPRequest` and
:class:`tornado.httputil.HTTPServerRequest` objects.
"""
self.request = request
if isinstance(request, HTTPRequest):
if request.url:
path, _, querystring = request.url.partition("?")
query_arguments: ImmutableMultiDict[str, str] = ImmutableMultiDict(
parse_qsl(querystring)
)
parts = urllib.parse.urlparse(request.url)
else:
path = ""
query_arguments = ImmutableMultiDict()
else:
path, _, _ = request.full_url().partition("?")
if path == "://":
path = ""
query_arguments = ImmutableMultiDict(
itertools.chain(
*[
[(k, v.decode("utf-8")) for v in vs]
for k, vs in request.query_arguments.items()
]
)
)
return OpenAPIRequest(
full_url_pattern=path,
method=request.method.lower() if request.method else "get",
parameters=RequestParameters(
query=query_arguments,
parts = urllib.parse.urlparse(request.full_url())
protocol = parts.scheme
host = parts.netloc
path = parts.path
query_arguments = parse_qsl(parts.query)
self.protocol = protocol
self.host = host
self.path = path
cookies = {}
for values in request.headers.get_list("Cookie"):
cookies.update(parse_cookie(values))
self.parameters = RequestParameters(
query=ImmutableMultiDict(query_arguments),
header=Headers(request.headers.get_all()),
cookie=parse_cookie(request.headers.get("Cookie", "")),
),
body=request.body if request.body else None,
mimetype=parse_mimetype(
request.headers.get("Content-Type", "application/x-www-form-urlencoded")
),
cookie=ImmutableMultiDict(cookies),
)
self.content_type = request.headers.get(
"Content-Type", "application/x-www-form-urlencoded"
)
@property
def host_url(self) -> str:
return "{}://{}".format(self.protocol, self.host)
class RequestValidator(validators.RequestValidator):
"""Validator for Tornado HTTP Requests."""
@property
def method(self) -> str:
method = self.request.method or "GET"
return method.lower()
def validate(
self, request: Union[HTTPRequest, HTTPServerRequest]
) -> RequestValidationResult:
"""Validate a Tornado HTTP request object."""
return super().validate(TornadoRequestFactory.create(request))
@property
def body(self) -> typing.Optional[bytes]:
return self.request.body
__all__ = ["RequestValidator", "TornadoRequestFactory"]
__all__ = ["TornadoOpenAPIRequest"]

View file

@ -1,37 +1,13 @@
from openapi_core.validation.response.datatypes import ( # type: ignore
OpenAPIResponse,
ResponseValidationResult,
)
from openapi_core.validation.response import validators # type: ignore
from tornado.httpclient import HTTPResponse
from .requests import TornadoRequestFactory
from .util import parse_mimetype
from werkzeug.datastructures import Headers
class TornadoResponseFactory:
"""Factory for converting Tornado responses to OpenAPI response objects."""
@classmethod
def create(cls, response: HTTPResponse) -> OpenAPIResponse:
"""Creates an OpenAPI response from Tornado response objects."""
mimetype = parse_mimetype(response.headers.get("Content-Type", "text/html"))
return OpenAPIResponse(
data=response.body if response.body else b"",
status_code=response.code,
mimetype=mimetype,
)
class TornadoOpenAPIResponse:
def __init__(self, response: HTTPResponse) -> None:
self.status_code = response.code
self.headers = Headers(response.headers.get_all())
self.content_type = response.headers.get("Content-Type", "text/html")
self.data = response.body
class ResponseValidator(validators.ResponseValidator):
"""Validator for Tornado HTTP Responses."""
def validate(self, response: HTTPResponse) -> ResponseValidationResult:
"""Validate a Tornado HTTP response object."""
return super().validate(
TornadoRequestFactory.create(response.request),
TornadoResponseFactory.create(response),
)
__all__ = ["ResponseValidator", "TornadoResponseFactory"]
__all__ = ["TornadoOpenAPIResponse"]

View file

@ -1,11 +1,12 @@
from typing import Any
import typing
import tornado.httpclient
import tornado.testing
import openapi_core
from openapi_core import create_spec # type: ignore
from openapi_core.spec.paths import SpecPath # type: ignore
from tornado_openapi3.responses import ResponseValidator
from tornado_openapi3.requests import TornadoOpenAPIRequest
from tornado_openapi3.responses import TornadoOpenAPIResponse
from tornado_openapi3.types import Deserializer, Formatter
class AsyncOpenAPITestCase(tornado.testing.AsyncHTTPTestCase):
@ -29,7 +30,7 @@ class AsyncOpenAPITestCase(tornado.testing.AsyncHTTPTestCase):
raise NotImplementedError()
@property
def spec(self) -> SpecPath:
def spec(self) -> openapi_core.OpenAPI:
"""The OpenAPI 3 specification.
Override this in your test cases to customize how your OpenAPI 3 spec is
@ -38,10 +39,21 @@ class AsyncOpenAPITestCase(tornado.testing.AsyncHTTPTestCase):
:rtype: :class:`openapi_core.schema.specs.model.Spec`
"""
return create_spec(self.spec_dict)
config = openapi_core.Config(
extra_format_unmarshallers={
format: formatter.unmarshal
for format, formatter in self.custom_formatters.items()
},
extra_format_validators={
format: formatter.validate
for format, formatter in self.custom_formatters.items()
},
extra_media_type_deserializers=self.custom_media_type_deserializers,
)
return openapi_core.OpenAPI.from_dict(self.spec_dict, config=config)
@property
def custom_formatters(self) -> dict:
def custom_formatters(self) -> typing.Dict[str, Formatter]:
"""A dictionary mapping value formats to formatter objects.
A formatter object must provide:
@ -52,7 +64,7 @@ class AsyncOpenAPITestCase(tornado.testing.AsyncHTTPTestCase):
return dict()
@property
def custom_media_type_deserializers(self) -> dict:
def custom_media_type_deserializers(self) -> typing.Dict[str, Deserializer]:
"""A dictionary mapping media types to deserializing functions.
If your endpoints make use of content types beyond ``application/json``,
@ -62,22 +74,8 @@ class AsyncOpenAPITestCase(tornado.testing.AsyncHTTPTestCase):
"""
return dict()
def setUp(self) -> None:
"""Hook method for setting up the test fixture before exercising it.
Instantiates the :class:`~tornado_openapi3.responses.ResponseValidator`
for this test case.
"""
super().setUp()
self.validator = ResponseValidator(
self.spec,
custom_formatters=self.custom_formatters,
custom_media_type_deserializers=self.custom_media_type_deserializers,
)
def fetch(
self, path: str, raise_error: bool = False, **kwargs: Any
self, path: str, raise_error: bool = False, **kwargs: typing.Any
) -> tornado.httpclient.HTTPResponse:
"""Convenience methiod to synchronously fetch a URL.
@ -95,7 +93,10 @@ class AsyncOpenAPITestCase(tornado.testing.AsyncHTTPTestCase):
return super().fetch(path, raise_error=raise_error, **kwargs)
response = super().fetch(path, raise_error=False, **kwargs)
result = self.validator.validate(response)
result = self.spec.unmarshal_response(
request=TornadoOpenAPIRequest(response.request),
response=TornadoOpenAPIResponse(response),
)
result.raise_for_errors()
if raise_error:
response.rethrow()

View file

@ -2,7 +2,7 @@ import typing
import typing_extensions
#: A type representing an OpenAPI deserializer.
Deserializer = typing.Callable[[typing.Union[bytes, str]], typing.Any]
Deserializer = typing.Callable[[bytes], typing.Any]
class Formatter(typing_extensions.Protocol):