"""Tests for sprockets.http mixins, runner, ``run()``, testing and logging."""
from unittest import mock
import asyncio
import contextlib
import distutils.dist
import distutils.errors
import logging
import os
import json
import time
import unittest
import uuid
import warnings

from tornado import concurrent, httpserver, httputil, ioloop, testing, web

import sprockets.http.app
import sprockets.http.mixins
import sprockets.http.runner
import sprockets.http.testing
import examples


class RecordingHandler(logging.Handler):
    """Logging handler that keeps every emitted record in memory.

    Each entry in :attr:`emitted` is a ``(record, formatted_message)`` pair.
    """

    def __init__(self):
        super().__init__()
        self.emitted = []

    def emit(self, record):
        self.emitted.append((record, self.format(record)))


class RaisingHandler(sprockets.http.mixins.ErrorLogger,
                     sprockets.http.mixins.ErrorWriter,
                     web.RequestHandler):
    """Request handler that always raises :class:`tornado.web.HTTPError`.

    The status code comes from the URL and the optional ``reason`` query
    argument is used as the error's reason phrase.
    """

    def get(self, status_code):
        raise web.HTTPError(int(status_code),
                            reason=self.get_query_argument('reason', None))


class MockHelper(unittest.TestCase):
    """Base test case that starts patches and stops them in ``tearDown``."""

    def setUp(self):
        super().setUp()
        self._mocks = []

    def tearDown(self):
        super().tearDown()
        for mocker in self._mocks:
            mocker.stop()
        del self._mocks[:]

    def start_mock(self, target, existing_mock=None):
        """Patch *target* until the end of the current test.

        :param str target: dotted path accepted by :func:`mock.patch`
        :param existing_mock: object to install instead of a new
            :class:`mock.Mock`
        :returns: the object installed at *target*
        """
        target_mock = mock.Mock() if existing_mock is None else existing_mock
        mocked = mock.patch(target, target_mock)
        self._mocks.append(mocked)
        return mocked.start()


@contextlib.contextmanager
def override_environment_variable(**env_vars):
    """Temporarily set environment variables; a value of ``None`` unsets.

    Every variable named in *env_vars* is restored on exit: previous values
    are put back and variables that did not exist beforehand are removed.
    """
    stash = {}
    try:
        # Mutate inside the ``try`` so a failure part-way through (for
        # example a non-str value) still rolls back what was changed.
        for name, value in env_vars.items():
            stash[name] = os.environ.pop(name, None)
            if value is not None:
                os.environ[name] = value
        yield
    finally:
        for name, value in stash.items():
            os.environ.pop(name, None)
            if value is not None:
                os.environ[name] = value


class ErrorLoggerTests(testing.AsyncHTTPTestCase):
    """Verify the ``ErrorLogger`` mixin's log level and message."""

    def setUp(self):
        super().setUp()
        self.recorder = RecordingHandler()
        root_logger = logging.getLogger()
        root_logger.addHandler(self.recorder)
        # A cleanup (rather than tearDown) detaches the handler even when
        # AsyncHTTPTestCase.tearDown raises.
        self.addCleanup(root_logger.removeHandler, self.recorder)

    def get_app(self):
        # Tornado passes named groups as keyword arguments, so the group
        # name must match the handlers' ``status_code`` parameter.
        return web.Application([
            web.url(r'/status/(?P<status_code>\d+)', examples.StatusHandler),
            web.url(r'/fail/(?P<status_code>\d+)', RaisingHandler),
        ])

    def assert_message_logged(self, level, msg_fmt, *msg_args):
        """Fail unless a record at *level* ends with the formatted text.

        :param int level: expected ``logging`` level number
        :param str msg_fmt: :meth:`str.format` template for the suffix
        :param msg_args: positional arguments for *msg_fmt*
        """
        suffix = msg_fmt.format(*msg_args)
        for record, message in self.recorder.emitted:
            if record.levelno == level and message.endswith(suffix):
                return
        self.fail('Expected message ending in "%s" to be logged in %r'
                  % (suffix, self.recorder.emitted))

    def test_that_client_error_logged_as_warning(self):
        self.fetch('/status/400')
        self.assert_message_logged(
            logging.WARNING, 'failed with 400: {}', httputil.responses[400])

    def test_that_server_error_logged_as_error(self):
        self.fetch('/status/500')
        self.assert_message_logged(
            logging.ERROR, 'failed with 500: {}', httputil.responses[500])

    def test_that_custom_status_codes_logged_as_unknown(self):
        self.fetch('/status/623')
        self.assert_message_logged(logging.ERROR, 'failed with 623: Unknown')

    def test_that_custom_reasons_are_supported(self):
        self.fetch('/status/456?reason=oops')
        self.assert_message_logged(logging.WARNING, 'failed with 456: oops')

    def test_that_status_code_extracted_from_http_errors(self):
        self.fetch('/fail/400')
        self.assert_message_logged(
            logging.WARNING, 'failed with 400: {}', httputil.responses[400])

    def test_that_reason_extracted_from_http_errors(self):
        self.fetch('/fail/400?reason=oopsie')
        self.assert_message_logged(logging.WARNING, 'failed with 400: oopsie')

    def test_that_log_message_is_honored(self):
        self.fetch('/status/400?log_message=injected%20message')
        self.assert_message_logged(logging.WARNING,
                                   'failed with 400: injected message')


class ErrorWriterTests(testing.AsyncHTTPTestCase):
    """Verify the JSON error bodies produced by the ``ErrorWriter`` mixin."""

    def setUp(self):
        # AsyncHTTPTestCase.setUp calls get_app(), which reads this lazily
        # initialised attribute, so it must exist first.
        self._application = None
        super().setUp()

    @property
    def application(self):
        """The per-test application; tests may mutate its ``settings``."""
        if self._application is None:
            self._application = web.Application([
                web.url(r'/status/(?P<status_code>\d+)',
                        examples.StatusHandler),
                web.url(r'/fail/(?P<status_code>\d+)', RaisingHandler),
            ])
        return self._application

    def get_app(self):
        return self.application

    def _decode_response(self, response):
        """Assert the response is JSON and return the decoded body."""
        content_type = response.headers['Content-Type']
        self.assertTrue(content_type.startswith('application/json'),
                        'Incorrect content type received')
        return json.loads(response.body.decode('utf-8'))

    def test_that_error_json_contains_error_type(self):
        response = self.fetch('/fail/400')
        self.assertEqual(response.code, 400)

        exc = web.HTTPError(400)
        body = self._decode_response(response)
        self.assertEqual(body['type'], exc.__class__.__name__)

    def test_that_error_json_contains_error_message(self):
        response = self.fetch('/fail/400')
        self.assertEqual(response.code, 400)

        exc = web.HTTPError(400)
        body = self._decode_response(response)
        self.assertEqual(body['message'], str(exc))

    def test_that_error_json_ignores_the_log_message(self):
        response = self.fetch('/status/500?log_message=something%20good')
        self.assertEqual(response.code, 500)

        body = self._decode_response(response)
        self.assertEqual(body['message'], httputil.responses[500])

    def test_that_error_json_contains_type_none_for_non_exceptions(self):
        response = self.fetch('/status/500')
        self.assertEqual(response.code, 500)

        body = self._decode_response(response)
        self.assertIsNone(body['type'])

    def test_that_error_json_contains_reason_for_non_exceptions(self):
        response = self.fetch('/status/500')
        self.assertEqual(response.code, 500)

        body = self._decode_response(response)
        self.assertEqual(body['message'], httputil.responses[500])

    def test_that_error_json_reason_contains_unknown_in_some_cases(self):
        response = self.fetch('/status/567')
        self.assertEqual(response.code, 567)

        body = self._decode_response(response)
        self.assertEqual(body['message'], 'Unknown')

    def test_that_error_json_honors_serve_traceback(self):
        self.application.settings['serve_traceback'] = True
        response = self.fetch('/fail/400')
        self.assertEqual(response.code, 400)

        body = self._decode_response(response)
        self.assertGreater(len(body['traceback']), 0)

    def test_that_mediatype_mixin_is_honored(self):
        send_response = mock.Mock()
        # patch.object restores the class even if the request fails, so
        # the mock cannot leak into other tests.
        with mock.patch.object(examples.StatusHandler, 'send_response',
                               send_response, create=True):
            response = self.fetch('/status/500')
        self.assertEqual(response.code, 500)
        send_response.assert_called_once_with({
            'type': None,
            'message': 'Internal Server Error',
            'traceback': None,
        })


class RunTests(MockHelper, unittest.TestCase):
    """Verify ``sprockets.http.run`` with the Runner and logging mocked."""

    def setUp(self):
        super().setUp()
        self.runner_cls = self.start_mock('sprockets.http.runner.Runner')
        self.get_logging_config = self.start_mock(
            'sprockets.http._get_logging_config')
        self.get_logging_config.return_value = {'version': 1}
        self.logging_dict_config = self.start_mock(
            'sprockets.http.logging.config').dictConfig

    @property
    def runner_instance(self):
        """The mocked Runner instance created by ``sprockets.http.run``."""
        return self.runner_cls.return_value

    def test_that_runner_run_called_with_created_application(self):
        create_app = mock.Mock()
        sprockets.http.run(create_app)
        self.assertEqual(create_app.call_count, 1)
        self.runner_cls.assert_called_once_with(create_app.return_value)

    def test_that_debug_envvar_enables_debug_flag(self):
        create_app = mock.Mock()
        with override_environment_variable(DEBUG='1'):
            sprockets.http.run(create_app)
        create_app.assert_called_once_with(debug=True)
        self.get_logging_config.assert_called_once_with(True)

    def test_that_false_debug_envvar_disables_debug_flag(self):
        create_app = mock.Mock()
        with override_environment_variable(DEBUG='0'):
            sprockets.http.run(create_app)
        create_app.assert_called_once_with(debug=False)
        self.get_logging_config.assert_called_once_with(False)

    def test_that_unset_debug_envvar_disables_debug_flag(self):
        create_app = mock.Mock()
        with override_environment_variable(DEBUG=None):
            sprockets.http.run(create_app)
        create_app.assert_called_once_with(debug=False)
        self.get_logging_config.assert_called_once_with(False)

    def test_that_port_defaults_to_8000(self):
        # PORT overrides the default (see the envvar test below), so make
        # sure an ambient value cannot leak into this test.
        with override_environment_variable(PORT=None):
            sprockets.http.run(mock.Mock())
        self.runner_instance.run.assert_called_once_with(8000, mock.ANY)

    def test_that_port_envvar_sets_port_number(self):
        with override_environment_variable(PORT='8888'):
            sprockets.http.run(mock.Mock())
        self.runner_instance.run.assert_called_once_with(8888, mock.ANY)

    def test_that_port_kwarg_sets_port_number(self):
        sprockets.http.run(mock.Mock(), settings={'port': 8888})
        self.runner_instance.run.assert_called_once_with(8888, mock.ANY)

    def test_that_number_of_procs_defaults_to_zero(self):
        sprockets.http.run(mock.Mock())
        self.runner_instance.run.assert_called_once_with(mock.ANY, 0)

    def test_that_number_of_process_kwarg_sets_number_of_procs(self):
        sprockets.http.run(mock.Mock(), settings={'number_of_procs': 1})
        self.runner_instance.run.assert_called_once_with(mock.ANY, 1)

    def test_that_logging_dict_config_is_called_appropriately(self):
        sprockets.http.run(mock.Mock())
        self.logging_dict_config.assert_called_once_with(
            self.get_logging_config.return_value)

    def test_that_logconfig_override_is_used(self):
        sprockets.http.run(mock.Mock(), log_config=mock.sentinel.config)
        self.logging_dict_config.assert_called_once_with(
            mock.sentinel.config)

    def test_that_not_specifying_logging_config_is_deprecated(self):
        with warnings.catch_warnings(record=True) as captured:
            warnings.simplefilter('always')
            sprockets.http.run(mock.Mock())

        self.assertEqual(len(captured), 1)
        self.assertTrue(issubclass(captured[0].category, DeprecationWarning))


class CallbackTests(MockHelper, unittest.TestCase):
    """Verify how the Runner invokes ``before_run`` and ``shutdown`` hooks."""

    def setUp(self):
        super().setUp()
        self.shutdown_callback = mock.Mock()
        self.before_run_callback = mock.Mock()
        self.application = self.make_application()

        self.io_loop = mock.Mock(_callbacks=[], _timeouts=[])
        self.io_loop.time.side_effect = time.time
        ioloop_module = self.start_mock('sprockets.http.runner.ioloop')
        ioloop_module.IOLoop.instance.return_value = self.io_loop

        self.start_mock('sprockets.http.runner.httpserver')

    def make_application(self, **settings):
        """Return a mock application wired to this test's callbacks."""
        application = mock.Mock()
        application.settings = settings.copy()
        application.runner_callbacks = {
            'before_run': [self.before_run_callback],
            'shutdown': [self.shutdown_callback],
        }
        return application

    def test_that_shutdown_callback_invoked(self):
        runner = sprockets.http.runner.Runner(self.application)
        runner.run(8080)
        runner._shutdown()
        self.shutdown_callback.assert_called_once_with(self.application)

    def test_that_exceptions_from_shutdown_callbacks_are_ignored(self):
        another_callback = mock.Mock()
        self.application.runner_callbacks['shutdown'].append(another_callback)
        self.shutdown_callback.side_effect = Exception

        runner = sprockets.http.runner.Runner(self.application)
        runner.run(8080)
        runner._shutdown()
        self.shutdown_callback.assert_called_once_with(self.application)
        another_callback.assert_called_once_with(self.application)

    def test_that_before_run_callback_invoked(self):
        runner = sprockets.http.runner.Runner(self.application)
        runner.run(8080)
        self.before_run_callback.assert_called_once_with(self.application,
                                                         self.io_loop)

    def test_that_exceptions_from_before_run_callbacks_are_terminal(self):
        another_callback = mock.Mock()
        self.application.runner_callbacks['before_run'].append(
            another_callback)
        self.before_run_callback.side_effect = Exception

        # sys.exit is replaced so the test can observe the exit code
        # without terminating the test process.
        sys_exit = mock.Mock()
        sys_exit.side_effect = SystemExit
        with mock.patch('sprockets.http.runner.sys') as sys_module:
            sys_module.exit = sys_exit
            with self.assertRaises(SystemExit):
                runner = sprockets.http.runner.Runner(self.application)
                runner.run(8080)

        self.before_run_callback.assert_called_once_with(self.application,
                                                         self.io_loop)
        another_callback.assert_not_called()
        self.shutdown_callback.assert_called_once_with(self.application)
        sys_exit.assert_called_once_with(70)


class RunnerTests(MockHelper, unittest.TestCase):
    """Verify Runner start-up, signal handling and shutdown."""

    def setUp(self):
        super().setUp()
        self.application = mock.Mock()
        self.application.settings = {
            'xheaders': True,
            'max_body_size': 2048,
            'max_buffer_size': 1024,
        }
        self.application.runner_callbacks = {}

        self.io_loop = mock.Mock()
        self.io_loop._callbacks = []
        self.io_loop._timeouts = []
        self.io_loop.time = time.time
        ioloop_module = self.start_mock('sprockets.http.runner.ioloop')
        ioloop_module.IOLoop.instance.return_value = self.io_loop

        self.http_server = mock.Mock(spec=httpserver.HTTPServer)
        self.httpserver_module = self.start_mock(
            'sprockets.http.runner.httpserver')
        self.httpserver_module.HTTPServer.return_value = self.http_server

    def test_that_run_starts_ioloop(self):
        runner = sprockets.http.runner.Runner(self.application)
        runner.run(8000)
        self.io_loop.start.assert_called_once_with()

    def test_that_http_server_settings_are_used(self):
        runner = sprockets.http.runner.Runner(self.application)
        runner.run(8000)
        self.httpserver_module.HTTPServer.assert_called_once_with(
            self.application, **self.application.settings)

    def test_that_production_run_starts_in_multiprocess_mode(self):
        runner = sprockets.http.runner.Runner(self.application)
        runner.run(8000)

        self.assertTrue(self.http_server.bind.called)
        args, kwargs = self.http_server.bind.call_args_list[0]
        self.assertEqual(args, (8000, ))

        self.http_server.start.assert_called_once_with(0)

    def test_that_production_enables_reuse_port(self):
        runner = sprockets.http.runner.Runner(self.application)
        runner.run(8000)

        self.assertTrue(self.http_server.bind.called)
        args, kwargs = self.http_server.bind.call_args_list[0]
        self.assertEqual(args, (8000, ))
        self.assertEqual(kwargs['reuse_port'], True)

    def test_that_debug_run_starts_in_singleprocess_mode(self):
        self.application.settings['debug'] = True
        runner = sprockets.http.runner.Runner(self.application)
        runner.run(8000)
        self.http_server.listen.assert_called_once_with(8000)
        self.http_server.start.assert_not_called()

    def test_that_initializer_creates_runner_callbacks_dict(self):
        application = web.Application()
        sprockets.http.runner.Runner(application)
        self.assertEqual(application.runner_callbacks['before_run'], [])
        self.assertEqual(application.runner_callbacks['on_start'], [])
        self.assertEqual(application.runner_callbacks['shutdown'], [])

    def test_that_signal_handler_invokes_shutdown(self):
        with mock.patch('sprockets.http.runner.signal') as signal_module:
            runner = sprockets.http.runner.Runner(self.application)
            runner.run(8000)

            signal_module.signal.assert_any_call(signal_module.SIGINT,
                                                 runner._on_signal)
            signal_module.signal.assert_any_call(signal_module.SIGTERM,
                                                 runner._on_signal)
            runner._on_signal(signal_module.SIGINT, mock.Mock())
            self.io_loop.add_callback_from_signal.assert_called_once_with(
                runner._shutdown)

    def test_that_shutdown_stops_after_timelimit(self):
        def add_timeout(_, callback):
            # Run the scheduled callback synchronously after a short delay
            # so the shutdown loop sees real time pass.
            time.sleep(0.1)
            callback()
        self.io_loop.add_timeout = mock.Mock(side_effect=add_timeout)

        # A non-empty timeout list keeps the runner waiting until the
        # shutdown limit expires.
        self.io_loop._timeouts = [mock.Mock()]
        runner = sprockets.http.runner.Runner(self.application)
        runner.shutdown_limit = 0.25
        runner.wait_timeout = 0.05
        runner.run(8000)
        runner._shutdown()
        self.io_loop.stop.assert_called_once_with()
        self.assertNotEqual(self.io_loop._timeouts, [])


class AsyncRunTests(unittest.TestCase):
    """Run a real Runner (with server start/stop patched out)."""

    def test_that_on_start_callbacks_are_invoked(self):
        future = concurrent.Future()

        def on_started(*args, **kwargs):
            with mock.patch('sprockets.http.runner.Runner.stop_server'):
                runner._shutdown()
            future.set_result(True)

        application = web.Application()
        with mock.patch('sprockets.http.runner.Runner.start_server'):
            runner = sprockets.http.runner.Runner(application,
                                                  on_start=[on_started])
            runner.run(8000)
        self.assertTrue(future.result())

    def test_that_shutdown_futures_are_waited_on(self):
        future = concurrent.Future()

        def on_started(*args, **kwargs):
            with mock.patch('sprockets.http.runner.Runner.stop_server'):
                runner._shutdown()

        def on_shutdown(*args, **kwargs):
            def shutdown_complete():
                future.set_result(True)

            # NOTE(review): ``1`` is an absolute IOLoop.time() deadline,
            # presumably already in the past, so this fires on the next
            # loop iteration rather than after one second.
            ioloop.IOLoop.current().add_timeout(1, shutdown_complete)
            return future

        application = web.Application()
        with mock.patch('sprockets.http.runner.Runner.start_server'):
            runner = sprockets.http.runner.Runner(application,
                                                  on_start=[on_started],
                                                  shutdown=[on_shutdown])
            runner.run(8000)

        self.assertTrue(future.result())


class RunCommandTests(MockHelper, unittest.TestCase):
    """Verify the ``RunCommand`` distutils command."""

    def setUp(self):
        super().setUp()
        self.distribution = mock.Mock(spec=distutils.dist.Distribution,
                                      verbose=3)

    @staticmethod
    def _record_find_callable(command):
        """Wrap ``command._find_callable`` so its return value is kept.

        :returns: a dict whose ``'result'`` key holds the value returned by
            the real method once the command has called it
        """
        recorded = {}
        real_method = command._find_callable

        def patched():
            recorded['result'] = real_method()
            return recorded['result']

        command._find_callable = patched
        return recorded

    def test_that_environment_file_is_processed(self):
        os_module = self.start_mock('sprockets.http.runner.os')
        os_module.environ = {'SHOULD_BE': 'REMOVED'}
        os_module.path.exists.return_value = True
        open_mock = mock.mock_open(read_data='\n'.join([
            'export SIMPLE=1',
            'NOT_EXPORTED=2 # with comment too!',
            'export DQUOTED="value with space"',
            "export SQUOTED='value with space'",
            'BAD LINE',
            '# commented line',
            'SHOULD_BE=',
        ]))
        self.start_mock('builtins.open', open_mock)

        command = sprockets.http.runner.RunCommand(self.distribution)
        command.dry_run = True
        command._find_callable = mock.Mock()
        command.env_file = 'name.conf'
        command.application = 'required.to:exist'
        command.ensure_finalized()
        command.run()

        os_module.path.exists.assert_called_once_with('name.conf')
        self.assertEqual(
            sorted(os_module.environ),
            sorted(['SIMPLE', 'NOT_EXPORTED', 'DQUOTED', 'SQUOTED']))
        self.assertEqual(os_module.environ['SIMPLE'], '1')
        self.assertEqual(os_module.environ['NOT_EXPORTED'], '2')
        self.assertEqual(os_module.environ['DQUOTED'], 'value with space')
        self.assertEqual(os_module.environ['SQUOTED'], 'value with space')

    def test_that_port_option_sets_environment_variable(self):
        os_module = self.start_mock('sprockets.http.runner.os')
        os_module.environ = {}
        os_module.path.exists.return_value = True
        open_mock = mock.mock_open(read_data='PORT=2')
        self.start_mock('builtins.open', open_mock)

        command = sprockets.http.runner.RunCommand(self.distribution)
        command.dry_run = True
        command._find_callable = mock.Mock()
        command.env_file = 'name.conf'
        command.application = 'required.to:exist'
        command.port = '3'
        command.ensure_finalized()
        command.run()

        self.assertEqual(os_module.environ['PORT'], '3')

    def test_that_application_callable_is_created(self):
        # Record the real _find_callable result instead of patching
        # __import__.
        command = sprockets.http.runner.RunCommand(self.distribution)
        recorded = self._record_find_callable(command)
        command.dry_run = True
        command.application = 'sprockets.http.runner:Runner'
        command.ensure_finalized()
        command.run()
        self.assertEqual(recorded['result'], sprockets.http.runner.Runner)

    def test_that_finalize_options_requires_application_option(self):
        command = sprockets.http.runner.RunCommand(self.distribution)
        command.env_file = 'not used here'
        with self.assertRaises(distutils.errors.DistutilsArgError):
            command.ensure_finalized()

    def test_that_finalize_options_with_nonexistant_env_file_fails(self):
        os_module = self.start_mock('sprockets.http.runner.os')
        os_module.path.exists.return_value = False

        command = sprockets.http.runner.RunCommand(self.distribution)
        command.application = examples.Application
        command.env_file = 'file.conf'
        with self.assertRaises(distutils.errors.DistutilsArgError):
            command.ensure_finalized()
        os_module.path.exists.assert_called_once_with('file.conf')

    def test_that_sprockets_http_run_is_called_appropriately(self):
        # The patch target goes through the ``sprockets`` name bound in the
        # runner module; odd looking, but intentional.
        run_function = self.start_mock(
            'sprockets.http.runner.sprockets.http.run')

        command = sprockets.http.runner.RunCommand(self.distribution)
        recorded = self._record_find_callable(command)
        command.application = 'examples:Application'
        command.dry_run = False
        command.ensure_finalized()
        command.run()

        run_function.assert_called_once_with(recorded['result'])


class TestCaseTests(unittest.TestCase):
    """Verify ``SprocketsHttpTestCase`` starts and stops the application."""

    class FakeTest(sprockets.http.testing.SprocketsHttpTestCase):
        def get_app(self):
            self.app = mock.Mock()
            return self.app

        def runTest(self):
            pass

    def test_that_setup_calls_start(self):
        test_case = self.FakeTest()
        test_case.setUp()
        test_case.app.start.assert_called_once_with(test_case.io_loop)

    def test_that_teardown_calls_stop(self):
        test_case = self.FakeTest()
        test_case.setUp()
        test_case.io_loop = mock.Mock()
        test_case.tearDown()
        test_case.app.stop.assert_called_once_with(
            test_case.io_loop, test_case.shutdown_limit,
            test_case.wait_timeout)


class CorrelationFilterTests(unittest.TestCase):
    """Verify ``_CorrelationFilter`` handling of ``correlation-id``."""

    def setUp(self):
        super().setUp()
        self.logger = logging.getLogger()
        self.record = self.logger.makeRecord(
            'name', logging.INFO, 'functionName', 42, 'hello %s',
            ('world',), (None, None, None))
        self.filter = sprockets.http._CorrelationFilter()

    def test_that_correlation_filter_adds_correlation_id(self):
        self.filter.filter(self.record)
        self.assertTrue(hasattr(self.record, 'correlation-id'))

    def test_that_correlation_filter_does_not_overwrite_correlation_id(self):
        some_value = str(uuid.uuid4())
        setattr(self.record, 'correlation-id', some_value)
        self.filter.filter(self.record)
        self.assertEqual(getattr(self.record, 'correlation-id'), some_value)


class LoggingConfigurationTests(unittest.TestCase):
    """Verify the dict returned by ``_get_logging_config``."""

    def test_that_debug_sets_log_level_to_debug(self):
        config = sprockets.http._get_logging_config(True)
        self.assertEqual(config['root']['level'], 'DEBUG')

    def test_that_not_debug_sets_log_level_to_info(self):
        config = sprockets.http._get_logging_config(False)
        self.assertEqual(config['root']['level'], 'INFO')

    def test_that_format_includes_sd_when_service_and_env_are_set(self):
        with override_environment_variable(SERVICE='service',
                                           ENVIRONMENT='whatever'):
            config = sprockets.http._get_logging_config(False)
        fmt_name = next(iter(config['formatters']))
        self.assertIn('service="service" environment="whatever"',
                      config['formatters'][fmt_name]['format'])


class ShutdownHandlerTests(unittest.TestCase):
    """Verify ``sprockets.http.app._ShutdownHandler``."""

    def setUp(self):
        super().setUp()
        self.io_loop = ioloop.IOLoop.current()

    @staticmethod
    def _cancel_task(loop, task):
        """Cancel *task* and let *loop* process the cancellation."""
        task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            loop.run_until_complete(task)

    def test_that_on_future_complete_logs_exceptions_from_future(self):
        future = concurrent.Future()
        future.set_exception(Exception('Injected Failure'))

        handler = sprockets.http.app._ShutdownHandler(self.io_loop, 0.2, 0.05)
        with self.assertLogs(handler.logger, 'WARNING') as cm:
            handler.on_shutdown_future_complete(future)
        self.assertEqual(len(cm.output), 1)
        self.assertIn('Injected Failure', cm.output[0])

    def test_that_on_future_complete_logs_active_exceptions(self):
        future = concurrent.Future()
        future.set_exception(Exception('Injected Failure'))

        handler = sprockets.http.app._ShutdownHandler(self.io_loop, 0.2, 0.05)
        with self.assertLogs(handler.logger, 'WARNING') as cm:
            try:
                future.result()
            except Exception:
                handler.on_shutdown_future_complete(future)
        self.assertEqual(len(cm.output), 1)
        self.assertIn('Injected Failure', cm.output[0])

    def test_that_maybe_stop_retries_until_tasks_are_complete(self):
        async def f():
            pass

        fake_loop = mock.Mock()
        fake_loop.time.return_value = 10
        loop = asyncio.get_event_loop()
        tasks = [loop.create_task(f()) for _ in range(5)]

        handler = sprockets.http.app._ShutdownHandler(fake_loop, 5.0, 0.0)
        handler.on_shutdown_ready()  # sets __deadline to 15
        fake_loop.add_timeout.reset_mock()

        while tasks:
            task = tasks.pop()
            # With tasks outstanding the handler must reschedule itself.
            handler._maybe_stop()
            fake_loop.add_timeout.assert_called_once_with(
                mock.ANY, handler._maybe_stop)
            fake_loop.add_timeout.reset_mock()
            loop.run_until_complete(task)
            del task  # drop the reference so the finished task is released

        handler._maybe_stop()
        fake_loop.stop.assert_called_once_with()

    def test_that_maybe_stop_terminates_when_deadline_reached(self):
        fake_loop = mock.Mock()
        fake_loop.time.return_value = 10
        loop = asyncio.get_event_loop()
        pending = loop.create_task(asyncio.sleep(10))
        # Never awaited by the test itself; cancel it afterwards so it does
        # not linger on the shared loop for later tests.
        self.addCleanup(self._cancel_task, loop, pending)

        handler = sprockets.http.app._ShutdownHandler(fake_loop, 5.0, 0.0)
        handler.on_shutdown_ready()  # sets __deadline to 15
        fake_loop.add_timeout.reset_mock()

        while fake_loop.time.return_value < 15:
            handler._maybe_stop()
            fake_loop.add_timeout.assert_called_once_with(
                mock.ANY, handler._maybe_stop)
            fake_loop.add_timeout.reset_mock()
            fake_loop.time.return_value += 1

        handler._maybe_stop()
        fake_loop.stop.assert_called_once_with()

        # asyncio.Task.all_tasks() was removed in Python 3.9; prefer the
        # module-level asyncio.all_tasks() (3.7+) when it exists.
        all_tasks = getattr(asyncio, 'all_tasks', None)
        if all_tasks is None:
            all_tasks = asyncio.Task.all_tasks
        self.assertEqual(len(all_tasks(loop)), 1)