diff --git a/tests/test_requests.py b/tests/test_requests.py index c85ebec..2d18ea4 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -15,14 +15,13 @@ from openapi_core.validation.request.datatypes import ( # type: ignore RequestParameters, OpenAPIRequest, ) -from openapi_core.validation.request.validators import RequestValidator # type: ignore from tornado.httpclient import HTTPRequest # type: ignore from tornado.httputil import HTTPHeaders, HTTPServerRequest # type: ignore from tornado.testing import AsyncHTTPTestCase # type: ignore from tornado.web import Application, RequestHandler # type: ignore from werkzeug.datastructures import ImmutableMultiDict -from tornado_openapi3 import TornadoRequestFactory +from tornado_openapi3 import RequestValidator, TornadoRequestFactory settings(deadline=None) @@ -55,54 +54,72 @@ class Parameters: return headers + qargs +field_name = s.text( + s.characters( + min_codepoint=33, + max_codepoint=126, + blacklist_categories=("Lu",), + blacklist_characters=":", + ), + min_size=1, +) +field_value = s.text( + s.characters(min_codepoint=0x20, max_codepoint=0x7E, blacklist_characters=" \r\n"), + min_size=1, +) + + +def headers(min_size=0): + return s.dictionaries(field_name, field_value, min_size=min_size) + + +def query_parameters(min_size=0): + return s.dictionaries(field_name, field_value, min_size=min_size) + + @s.composite def parameters(draw, min_headers=0, min_query_parameters=0) -> Parameters: - field_name = s.text( - s.characters( - min_codepoint=33, - max_codepoint=126, - blacklist_categories=("Lu",), - blacklist_characters=":", - ), - min_size=1, - ) - field_value = s.text( - s.characters( - min_codepoint=0x20, max_codepoint=0x7E, blacklist_characters=" \r\n" - ), - min_size=1, - ) return Parameters( - headers=draw(s.dictionaries(field_name, field_value, min_size=min_headers)), - query_parameters=draw( - s.dictionaries(field_name, field_value, min_size=min_query_parameters) - ), + headers=draw(headers(min_size=min_headers)), + query_parameters=draw(query_parameters(min_size=min_query_parameters)), ) class TestRequestFactory(unittest.TestCase): - def test_http_request(self) -> None: - tornado_request = HTTPRequest( - method="GET", url="http://example.com/foo?bar=baz" + @given( + s.one_of( + s.tuples(s.just(""), s.just(dict())), + s.tuples(s.just("http://example.com/foo"), query_parameters()), ) + ) + def test_http_request(self, opts) -> None: + url, parameters = opts + request_url = f"{url}?{urlencode(parameters)}" if url else "" + tornado_request = HTTPRequest(method="GET", url=request_url) expected = OpenAPIRequest( - full_url_pattern="http://example.com/foo", + full_url_pattern=url, method="get", - parameters=RequestParameters(query=ImmutableMultiDict([("bar", "baz")])), + parameters=RequestParameters(query=ImmutableMultiDict(parameters)), body="", mimetype="application/x-www-form-urlencoded", ) openapi_request = TornadoRequestFactory.create(tornado_request) self.assertEqual(attr.asdict(expected), attr.asdict(openapi_request)) - def test_http_server_request(self) -> None: - tornado_request = HTTPServerRequest( - method="GET", uri="http://example.com/foo?bar=baz" + @given( + s.one_of( + s.tuples(s.just(""), s.just(dict())), + s.tuples(s.just("http://example.com/foo"), query_parameters()), ) + ) + def test_http_server_request(self, opts) -> None: + url, parameters = opts + request_url = f"{url}?{urlencode(parameters)}" if url else "" + tornado_request = HTTPServerRequest(method="GET", uri=request_url) expected = OpenAPIRequest( - full_url_pattern="http://example.com/foo", + full_url_pattern=url, method="get", - parameters=RequestParameters(query=ImmutableMultiDict([("bar", "baz")])), + parameters=RequestParameters(query=ImmutableMultiDict(parameters)), body="", mimetype="application/x-www-form-urlencoded", ) @@ -113,7 +130,7 @@ class TestRequestFactory(unittest.TestCase): class TestRequest(AsyncHTTPTestCase): def setUp(self) -> None: super(TestRequest, self).setUp() - self.request: Optional[OpenAPIRequest] = None + self.request: Optional[HTTPServerRequest] = None def get_app(self) -> Application: testcase = self @@ -121,7 +138,7 @@ class TestRequest(AsyncHTTPTestCase): class TestHandler(RequestHandler): def get(self) -> None: nonlocal testcase - testcase.request = TornadoRequestFactory.create(self.request) + testcase.request = self.request return Application([(r"/.*", TestHandler)]) @@ -146,6 +163,7 @@ class TestRequest(AsyncHTTPTestCase): "/?" + urlencode(parameters.query_parameters), headers=HTTPHeaders(parameters.headers), ) + assert self.request is not None result = validator.validate(self.request) result.raise_for_errors() @@ -169,6 +187,7 @@ class TestRequest(AsyncHTTPTestCase): ) validator = RequestValidator(spec) self.fetch("/") + assert self.request is not None result = validator.validate(self.request) with self.assertRaises(MissingRequiredParameter): result.raise_for_errors() @@ -197,6 +216,7 @@ class TestRequest(AsyncHTTPTestCase): ) validator = RequestValidator(spec) self.fetch("/1234") + assert self.request is not None result = validator.validate(self.request) result.raise_for_errors() @@ -224,6 +244,7 @@ class TestRequest(AsyncHTTPTestCase): ) validator = RequestValidator(spec) self.fetch("/abcd") + assert self.request is not None result = validator.validate(self.request) with self.assertRaises(OpenAPIError): result.raise_for_errors()