Add additional request tests

- Cover empty url cases
- Test using the new RequestValidator object
This commit is contained in:
Correl Roush 2020-11-26 22:24:58 -05:00
parent 0b0bb57f2e
commit 8c750bc16c

View file

@ -15,14 +15,13 @@ from openapi_core.validation.request.datatypes import ( # type: ignore
RequestParameters, RequestParameters,
OpenAPIRequest, OpenAPIRequest,
) )
from openapi_core.validation.request.validators import RequestValidator # type: ignore
from tornado.httpclient import HTTPRequest # type: ignore from tornado.httpclient import HTTPRequest # type: ignore
from tornado.httputil import HTTPHeaders, HTTPServerRequest # type: ignore from tornado.httputil import HTTPHeaders, HTTPServerRequest # type: ignore
from tornado.testing import AsyncHTTPTestCase # type: ignore from tornado.testing import AsyncHTTPTestCase # type: ignore
from tornado.web import Application, RequestHandler # type: ignore from tornado.web import Application, RequestHandler # type: ignore
from werkzeug.datastructures import ImmutableMultiDict from werkzeug.datastructures import ImmutableMultiDict
from tornado_openapi3 import TornadoRequestFactory from tornado_openapi3 import RequestValidator, TornadoRequestFactory
settings(deadline=None) settings(deadline=None)
@ -55,54 +54,72 @@ class Parameters:
return headers + qargs return headers + qargs
field_name = s.text(
s.characters(
min_codepoint=33,
max_codepoint=126,
blacklist_categories=("Lu",),
blacklist_characters=":",
),
min_size=1,
)
field_value = s.text(
s.characters(min_codepoint=0x20, max_codepoint=0x7E, blacklist_characters=" \r\n"),
min_size=1,
)
def headers(min_size=0):
return s.dictionaries(field_name, field_value, min_size=min_size)
def query_parameters(min_size=0):
return s.dictionaries(field_name, field_value, min_size=min_size)
@s.composite @s.composite
def parameters(draw, min_headers=0, min_query_parameters=0) -> Parameters: def parameters(draw, min_headers=0, min_query_parameters=0) -> Parameters:
field_name = s.text(
s.characters(
min_codepoint=33,
max_codepoint=126,
blacklist_categories=("Lu",),
blacklist_characters=":",
),
min_size=1,
)
field_value = s.text(
s.characters(
min_codepoint=0x20, max_codepoint=0x7E, blacklist_characters=" \r\n"
),
min_size=1,
)
return Parameters( return Parameters(
headers=draw(s.dictionaries(field_name, field_value, min_size=min_headers)), headers=draw(headers(min_size=min_headers)),
query_parameters=draw( query_parameters=draw(query_parameters(min_size=min_query_parameters)),
s.dictionaries(field_name, field_value, min_size=min_query_parameters)
),
) )
class TestRequestFactory(unittest.TestCase): class TestRequestFactory(unittest.TestCase):
def test_http_request(self) -> None: @given(
tornado_request = HTTPRequest( s.one_of(
method="GET", url="http://example.com/foo?bar=baz" s.tuples(s.just(""), s.just(dict())),
s.tuples(s.just("http://example.com/foo"), query_parameters()),
) )
)
def test_http_request(self, opts) -> None:
url, parameters = opts
request_url = f"{url}?{urlencode(parameters)}" if url else ""
tornado_request = HTTPRequest(method="GET", url=request_url)
expected = OpenAPIRequest( expected = OpenAPIRequest(
full_url_pattern="http://example.com/foo", full_url_pattern=url,
method="get", method="get",
parameters=RequestParameters(query=ImmutableMultiDict([("bar", "baz")])), parameters=RequestParameters(query=ImmutableMultiDict(parameters)),
body="", body="",
mimetype="application/x-www-form-urlencoded", mimetype="application/x-www-form-urlencoded",
) )
openapi_request = TornadoRequestFactory.create(tornado_request) openapi_request = TornadoRequestFactory.create(tornado_request)
self.assertEqual(attr.asdict(expected), attr.asdict(openapi_request)) self.assertEqual(attr.asdict(expected), attr.asdict(openapi_request))
def test_http_server_request(self) -> None: @given(
tornado_request = HTTPServerRequest( s.one_of(
method="GET", uri="http://example.com/foo?bar=baz" s.tuples(s.just(""), s.just(dict())),
s.tuples(s.just("http://example.com/foo"), query_parameters()),
) )
)
def test_http_server_request(self, opts) -> None:
url, parameters = opts
request_url = f"{url}?{urlencode(parameters)}" if url else ""
tornado_request = HTTPServerRequest(method="GET", uri=request_url)
expected = OpenAPIRequest( expected = OpenAPIRequest(
full_url_pattern="http://example.com/foo", full_url_pattern=url,
method="get", method="get",
parameters=RequestParameters(query=ImmutableMultiDict([("bar", "baz")])), parameters=RequestParameters(query=ImmutableMultiDict(parameters)),
body="", body="",
mimetype="application/x-www-form-urlencoded", mimetype="application/x-www-form-urlencoded",
) )
@ -113,7 +130,7 @@ class TestRequestFactory(unittest.TestCase):
class TestRequest(AsyncHTTPTestCase): class TestRequest(AsyncHTTPTestCase):
def setUp(self) -> None: def setUp(self) -> None:
super(TestRequest, self).setUp() super(TestRequest, self).setUp()
self.request: Optional[OpenAPIRequest] = None self.request: Optional[HTTPServerRequest] = None
def get_app(self) -> Application: def get_app(self) -> Application:
testcase = self testcase = self
@ -121,7 +138,7 @@ class TestRequest(AsyncHTTPTestCase):
class TestHandler(RequestHandler): class TestHandler(RequestHandler):
def get(self) -> None: def get(self) -> None:
nonlocal testcase nonlocal testcase
testcase.request = TornadoRequestFactory.create(self.request) testcase.request = self.request
return Application([(r"/.*", TestHandler)]) return Application([(r"/.*", TestHandler)])
@ -146,6 +163,7 @@ class TestRequest(AsyncHTTPTestCase):
"/?" + urlencode(parameters.query_parameters), "/?" + urlencode(parameters.query_parameters),
headers=HTTPHeaders(parameters.headers), headers=HTTPHeaders(parameters.headers),
) )
assert self.request is not None
result = validator.validate(self.request) result = validator.validate(self.request)
result.raise_for_errors() result.raise_for_errors()
@ -169,6 +187,7 @@ class TestRequest(AsyncHTTPTestCase):
) )
validator = RequestValidator(spec) validator = RequestValidator(spec)
self.fetch("/") self.fetch("/")
assert self.request is not None
result = validator.validate(self.request) result = validator.validate(self.request)
with self.assertRaises(MissingRequiredParameter): with self.assertRaises(MissingRequiredParameter):
result.raise_for_errors() result.raise_for_errors()
@ -197,6 +216,7 @@ class TestRequest(AsyncHTTPTestCase):
) )
validator = RequestValidator(spec) validator = RequestValidator(spec)
self.fetch("/1234") self.fetch("/1234")
assert self.request is not None
result = validator.validate(self.request) result = validator.validate(self.request)
result.raise_for_errors() result.raise_for_errors()
@ -224,6 +244,7 @@ class TestRequest(AsyncHTTPTestCase):
) )
validator = RequestValidator(spec) validator = RequestValidator(spec)
self.fetch("/abcd") self.fetch("/abcd")
assert self.request is not None
result = validator.validate(self.request) result = validator.validate(self.request)
with self.assertRaises(OpenAPIError): with self.assertRaises(OpenAPIError):
result.raise_for_errors() result.raise_for_errors()