tornado-openapi3/tests/test_requests.py
Correl Roush 4102a14883 Leave request and response bodies as binary data
Decoding to UTF-8 text is the responsibility of the content-type loader
used in the validator, and should not be done prior.
2021-01-06 14:50:15 -05:00

255 lines
8.7 KiB
Python

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple
import unittest
from urllib.parse import urlencode, urlparse
import attr
from hypothesis import given
import hypothesis.strategies as s # type: ignore
from openapi_core import create_spec # type: ignore
from openapi_core.exceptions import OpenAPIError # type: ignore
from openapi_core.schema.parameters.exceptions import ( # type: ignore
MissingRequiredParameter,
)
from openapi_core.validation.request.datatypes import ( # type: ignore
RequestParameters,
OpenAPIRequest,
)
from tornado.httpclient import HTTPRequest # type: ignore
from tornado.httputil import HTTPHeaders, HTTPServerRequest # type: ignore
from tornado.testing import AsyncHTTPTestCase # type: ignore
from tornado.web import Application, RequestHandler # type: ignore
from werkzeug.datastructures import ImmutableMultiDict
from tornado_openapi3 import RequestValidator, TornadoRequestFactory
@dataclass
class Parameters:
    """Header and query-string values shared by a generated request and its spec.

    Both mappings go from a parameter name to the single value that the
    OpenAPI definition produced by :meth:`as_openapi` accepts for it.
    """

    headers: Dict[str, str]
    query_parameters: Dict[str, str]

    def as_openapi(self) -> List[dict]:
        """Return OpenAPI parameter objects for every header and query argument.

        Each entry is marked ``required`` and constrained by a one-element
        ``enum`` holding the stored value. Names are lower-cased. Header
        entries come first, followed by query entries, each group in the
        iteration order of its mapping.
        """
        definitions: List[dict] = []
        sources = (
            ("header", self.headers),
            ("query", self.query_parameters),
        )
        for location, values in sources:
            for name, value in values.items():
                definitions.append(
                    {
                        "name": name.lower(),
                        "in": location,
                        "required": True,
                        "schema": {"type": "string", "enum": [value]},
                    }
                )
        return definitions
# Strategy for parameter names: non-empty strings of printable, non-space
# ASCII (code points 0x21-0x7E) with uppercase letters (category "Lu") and
# ":" excluded.
# NOTE(review): presumably uppercase is excluded because names are lower-cased
# when the spec is built, and ":" because it delimits header names - confirm.
field_name = s.text(
    alphabet=s.characters(
        min_codepoint=0x21,
        max_codepoint=0x7E,
        blacklist_categories=("Lu",),
        blacklist_characters=":",
    ),
    min_size=1,
)
# Strategy for parameter values: non-empty strings drawn from ASCII code
# points 0x20-0x7E, with space, carriage return and line feed excluded.
field_value = s.text(
    alphabet=s.characters(
        min_codepoint=0x20,
        max_codepoint=0x7E,
        blacklist_characters=" \r\n",
    ),
    min_size=1,
)
def headers(min_size: int = 0) -> s.SearchStrategy[Dict[str, str]]:
    """Build a strategy that generates header dictionaries.

    Keys are drawn from ``field_name`` and values from ``field_value``; every
    generated dictionary has at least ``min_size`` entries.
    """
    return s.dictionaries(
        keys=field_name,
        values=field_value,
        min_size=min_size,
    )
def query_parameters(min_size: int = 0) -> s.SearchStrategy[Dict[str, str]]:
    """Build a strategy that generates query-parameter dictionaries.

    Keys are drawn from ``field_name`` and values from ``field_value``; every
    generated dictionary has at least ``min_size`` entries.
    """
    return s.dictionaries(
        keys=field_name,
        values=field_value,
        min_size=min_size,
    )
@s.composite
def parameters(
    draw: Callable[[Any], Any], min_headers: int = 0, min_query_parameters: int = 0
) -> Parameters:
    """Composite strategy producing a ``Parameters`` instance.

    Headers are drawn first, then query parameters. ``min_headers`` and
    ``min_query_parameters`` set the minimum number of entries in the
    respective dictionaries.
    """
    drawn_headers = draw(headers(min_size=min_headers))
    drawn_query = draw(query_parameters(min_size=min_query_parameters))
    return Parameters(headers=drawn_headers, query_parameters=drawn_query)
class TestRequestFactory(unittest.TestCase):
    """Check the ``OpenAPIRequest`` built by ``TornadoRequestFactory.create``."""

    def _assert_converted(
        self, tornado_request: Any, url: str, query: Dict[str, str]
    ) -> None:
        """Assert that ``tornado_request`` converts to the expected GET request.

        The expected request uses ``url`` as its full URL pattern, ``query``
        as its query parameters, an empty binary body and the
        ``application/x-www-form-urlencoded`` mimetype.
        """
        expected = OpenAPIRequest(
            full_url_pattern=url,
            method="get",
            parameters=RequestParameters(query=ImmutableMultiDict(query)),
            body=b"",
            mimetype="application/x-www-form-urlencoded",
        )
        actual = TornadoRequestFactory.create(tornado_request)
        self.assertEqual(attr.asdict(expected), attr.asdict(actual))

    @given(
        s.one_of(
            s.tuples(s.just(""), s.just(dict())),
            s.tuples(s.just("http://example.com/foo"), query_parameters()),
        )
    )
    def test_http_request(self, opts: Tuple[str, Dict[str, str]]) -> None:
        """Client-side ``HTTPRequest`` objects are converted correctly.

        Covers both an empty URL with no query and a fixed URL with generated
        query parameters.
        """
        url, query = opts
        if url:
            request_url = f"{url}?{urlencode(query)}"
        else:
            request_url = ""
        client_request = HTTPRequest(method="GET", url=request_url)
        self._assert_converted(client_request, url, query)

    @given(
        s.one_of(
            s.tuples(s.just(""), s.just(dict())),
            s.tuples(s.just("http://example.com/foo"), query_parameters()),
        )
    )
    def test_http_server_request(self, opts: Tuple[str, Dict[str, str]]) -> None:
        """Server-side ``HTTPServerRequest`` objects are converted correctly.

        The server request is assembled from the parsed pieces of the URL:
        path and query form the URI, while scheme and host name (without any
        port) are assigned to ``protocol`` and ``host``.
        """
        url, query = opts
        if url:
            request_url = f"{url}?{urlencode(query)}"
        else:
            request_url = ""
        parts = urlparse(request_url)
        server_request = HTTPServerRequest(
            method="GET",
            uri=f"{parts.path}?{parts.query}",
        )
        server_request.protocol = parts.scheme
        # Keep only the host name; drop a ":port" suffix if one is present.
        server_request.host = parts.netloc.partition(":")[0]
        self._assert_converted(server_request, url, query)
class TestRequest(AsyncHTTPTestCase):
    """Validate requests received by a live Tornado handler against OpenAPI specs.

    The application returned by :meth:`get_app` stores every incoming GET
    request on the test case as ``self.request`` so the tests can hand the
    server-side request object to ``RequestValidator``.
    """

    def setUp(self) -> None:
        """Start the test server and clear the captured request."""
        super().setUp()
        self.request: Optional[HTTPServerRequest] = None

    def get_app(self) -> Application:
        """Return an application whose handler records each GET request."""
        testcase = self

        class TestHandler(RequestHandler):
            def get(self) -> None:
                # Only an attribute of the enclosing test case is assigned;
                # the closure variable itself is never rebound.
                testcase.request = self.request

        return Application([(r"/.*", TestHandler)])

    def fetch(self, path: str, raise_error: bool = False, **kwargs: Any) -> Any:
        """Fetch ``path`` from the test server after clearing ``self.request``.

        Hypothesis runs all examples of a ``@given`` test inside a single
        ``setUp``/``tearDown`` cycle, so without this reset a fetch that never
        reached the handler would leave the previous example's request in
        place and the test would silently validate stale data.
        """
        self.request = None
        return super().fetch(path, raise_error=raise_error, **kwargs)

    @staticmethod
    def _make_spec(path: str, parameter_objects: List[dict]) -> Any:
        """Create a spec with one GET operation at ``path`` taking the given parameters."""
        return create_spec(
            {
                "openapi": "3.0.0",
                "info": {"title": "Test specification", "version": "0.1"},
                "paths": {
                    path: {
                        "get": {
                            "parameters": parameter_objects,
                            "responses": {"default": {"description": "Root response"}},
                        }
                    }
                },
            }
        )

    @classmethod
    def _make_id_spec(cls) -> Any:
        """Create a spec for ``/{id}`` with a required integer path parameter."""
        return cls._make_spec(
            "/{id}",
            [
                {
                    "name": "id",
                    "in": "path",
                    "required": True,
                    "schema": {"type": "integer"},
                }
            ],
        )

    @given(parameters())
    def test_simple_request(self, parameters: Parameters) -> None:
        """A request carrying every declared header and query value validates."""
        validator = RequestValidator(self._make_spec("/", parameters.as_openapi()))
        self.fetch(
            "/?" + urlencode(parameters.query_parameters),
            headers=HTTPHeaders(parameters.headers),
        )
        assert self.request is not None
        result = validator.validate(self.request)
        result.raise_for_errors()

    @given(parameters(min_headers=1) | parameters(min_query_parameters=1))
    def test_simple_request_fails_without_parameters(
        self, parameters: Parameters
    ) -> None:
        """A bare request fails when at least one required parameter is declared."""
        validator = RequestValidator(self._make_spec("/", parameters.as_openapi()))
        self.fetch("/")
        assert self.request is not None
        result = validator.validate(self.request)
        with self.assertRaises(MissingRequiredParameter):
            result.raise_for_errors()

    def test_url_parameters(self) -> None:
        """An integer path segment satisfies the integer ``id`` path parameter."""
        validator = RequestValidator(self._make_id_spec())
        self.fetch("/1234")
        assert self.request is not None
        result = validator.validate(self.request)
        result.raise_for_errors()

    def test_bad_url_parameters(self) -> None:
        """A non-integer path segment is reported as an ``OpenAPIError``."""
        validator = RequestValidator(self._make_id_spec())
        self.fetch("/abcd")
        assert self.request is not None
        result = validator.validate(self.request)
        with self.assertRaises(OpenAPIError):
            result.raise_for_errors()