Add support for tornado.httpclient.HTTPRequest

This commit is contained in:
Correl Roush 2020-11-25 22:23:23 -05:00
parent a5aa5adc9f
commit 53471f3fce
3 changed files with 50 additions and 16 deletions

View file

@ -16,6 +16,7 @@ from openapi_core.validation.request.datatypes import ( # type: ignore
OpenAPIRequest, OpenAPIRequest,
) )
from openapi_core.validation.request.validators import RequestValidator # type: ignore from openapi_core.validation.request.validators import RequestValidator # type: ignore
from tornado.httpclient import HTTPRequest # type: ignore
from tornado.httputil import HTTPHeaders, HTTPServerRequest # type: ignore from tornado.httputil import HTTPHeaders, HTTPServerRequest # type: ignore
from tornado.testing import AsyncHTTPTestCase # type: ignore from tornado.testing import AsyncHTTPTestCase # type: ignore
from tornado.web import Application, RequestHandler # type: ignore from tornado.web import Application, RequestHandler # type: ignore
@ -80,7 +81,21 @@ def parameters(draw, min_headers=0, min_query_parameters=0) -> Parameters:
class TestRequestFactory(unittest.TestCase): class TestRequestFactory(unittest.TestCase):
def test_request(self) -> None: def test_http_request(self) -> None:
tornado_request = HTTPRequest(
method="GET", url="http://example.com/foo?bar=baz"
)
expected = OpenAPIRequest(
full_url_pattern="http://example.com/foo",
method="get",
parameters=RequestParameters(query=ImmutableMultiDict([("bar", "baz")])),
body="",
mimetype="application/x-www-form-urlencoded",
)
openapi_request = TornadoRequestFactory.create(tornado_request)
self.assertEqual(attr.asdict(expected), attr.asdict(openapi_request))
def test_http_server_request(self) -> None:
tornado_request = HTTPServerRequest( tornado_request = HTTPServerRequest(
method="GET", uri="http://example.com/foo?bar=baz" method="GET", uri="http://example.com/foo?bar=baz"
) )

View file

@ -1,3 +1,6 @@
from tornado_openapi3.requests import TornadoRequestFactory from tornado_openapi3.requests import RequestValidator, TornadoRequestFactory
__all__ = ["TornadoRequestFactory"] __all__ = [
"RequestValidator",
"TornadoRequestFactory",
]

View file

@ -1,22 +1,35 @@
import itertools import itertools
from urllib.parse import parse_qsl
from typing import Union
from openapi_core.validation.request.datatypes import ( # type: ignore from openapi_core.validation.request.datatypes import ( # type: ignore
RequestParameters, RequestParameters,
OpenAPIRequest, OpenAPIRequest,
) )
from openapi_core.validation.request import validators # type: ignore from openapi_core.validation.request import validators # type: ignore
from tornado.httpclient import HTTPRequest # type: ignore
from tornado.httputil import HTTPServerRequest # type: ignore from tornado.httputil import HTTPServerRequest # type: ignore
from werkzeug.datastructures import ImmutableMultiDict, Headers from werkzeug.datastructures import ImmutableMultiDict, Headers
class TornadoRequestFactory: class TornadoRequestFactory:
@classmethod @classmethod
def create(cls, request: HTTPServerRequest) -> OpenAPIRequest: def create(cls, request: Union[HTTPRequest, HTTPServerRequest]) -> OpenAPIRequest:
if isinstance(request, HTTPRequest):
if request.url:
path, _, querystring = request.url.partition("?")
query_arguments: ImmutableMultiDict[str, str] = ImmutableMultiDict(
parse_qsl(querystring)
)
else:
path = ""
query_arguments = ImmutableMultiDict()
else:
if request.uri: if request.uri:
path, _, _ = request.uri.partition("?") path, _, _ = request.uri.partition("?")
else: else:
path = "" path = ""
query_arguments: ImmutableMultiDict[str, str] = ImmutableMultiDict( query_arguments = ImmutableMultiDict(
itertools.chain( itertools.chain(
*[ *[
[(k, v.decode("utf-8")) for v in vs] [(k, v.decode("utf-8")) for v in vs]
@ -30,7 +43,7 @@ class TornadoRequestFactory:
parameters=RequestParameters( parameters=RequestParameters(
query=query_arguments, header=Headers(request.headers.get_all()) query=query_arguments, header=Headers(request.headers.get_all())
), ),
body=request.body.decode("utf-8"), body=request.body.decode("utf-8") if request.body else "",
mimetype=request.headers.get( mimetype=request.headers.get(
"Content-Type", "application/x-www-form-urlencoded" "Content-Type", "application/x-www-form-urlencoded"
), ),
@ -38,5 +51,8 @@ class TornadoRequestFactory:
class RequestValidator(validators.RequestValidator): class RequestValidator(validators.RequestValidator):
def validate(self, request: HTTPServerRequest): def validate(self, request: Union[HTTPRequest, HTTPServerRequest]):
return super().validate(TornadoRequestFactory.create(request)) return super().validate(TornadoRequestFactory.create(request))
__all__ = ["RequestValidator", "TornadoRequestFactory"]