From c1349ba8b2f4efa425147ec7a655fc074260ee18 Mon Sep 17 00:00:00 2001 From: Dave Shawley Date: Mon, 14 Sep 2015 07:27:47 -0400 Subject: [PATCH] Rewrite tests for sprockets.logging.tornado_log_function This exposes the defect that we found in production --- tests.py | 238 +++++++++++++++++++++++-------------------------------- 1 file changed, 97 insertions(+), 141 deletions(-) diff --git a/tests.py b/tests.py index d2aba0e..6fbba6d 100644 --- a/tests.py +++ b/tests.py @@ -1,172 +1,128 @@ import json import logging import os -import random import unittest import uuid -import mock - -import sprockets.logging from tornado import web, testing -LOGGER = logging.getLogger(__name__) -os.environ['ENVIRONMENT'] = 'testing' - -class Prototype(object): - pass +import sprockets.logging -class RecordingHandler(logging.FileHandler): - def __init__(self): - logging.FileHandler.__init__(self, filename='/dev/null') - self.log_lines = [] - - def format(self, record): - log_line = logging.FileHandler.format(self, record) - self.log_lines.append(log_line) - return log_line +def setup_module(): + os.environ.setdefault('ENVIRONMENT', 'development') -class ContextFilterTests(unittest.TestCase): +class SimpleHandler(web.RequestHandler): - def setUp(self): - super(ContextFilterTests, self).setUp() - self.logger = logging.getLogger(uuid.uuid4().hex) - self.handler = RecordingHandler() - self.logger.addHandler(self.handler) - - def test_that_filter_blocks_key_errors(self): - formatter = logging.Formatter('%(message)s [%(context)s]') - self.handler.setFormatter(formatter) - self.handler.addFilter(sprockets.logging.ContextFilter( - properties=['context'])) - self.logger.info('hi there') - - def test_that_filter_does_not_overwrite_extras(self): - formatter = logging.Formatter('%(message)s [%(context)s]') - self.handler.setFormatter(formatter) - self.handler.addFilter(sprockets.logging.ContextFilter( - properties=['context'])) - self.logger.info('hi there', extra={'context': 'foo'}) - self.assertEqual(self.handler.log_lines[-1], 'hi there [foo]') + def get(self): + if self.get_query_argument('runtime_error', default=None): + raise RuntimeError(self.get_query_argument('runtime_error')) + if self.get_query_argument('status_code', default=None) is not None: + self.set_status(int(self.get_query_argument('status_code'))) + else: + self.set_status(204) -class MockRequest(object): - - headers = {'Accept': 'application/msgpack', - 'Correlation-ID': str(uuid.uuid4())} - method = 'GET' - path = '/test' - protocol = 'http' - remote_ip = '127.0.0.1' - query_arguments = {'mock': True} +class RecordingHandler(logging.Handler): def __init__(self): - self.duration = random.randint(10, 200) + super(RecordingHandler, self).__init__() + self.emitted = [] - def request_time(self): - return self.duration + def emit(self, record): + self.emitted.append((record, self.format(record))) -class MockHandler(object): - - def __init__(self, status_code=200): - self.status_code = status_code - self.request = MockRequest() - - def get_status(self): - return self.status_code - - -class TornadoLogFunctionTestCase(unittest.TestCase): - - @mock.patch('tornado.log.access_log') - def test_log_function_return_value(self, access_log): - handler = MockHandler() - expectation = ('', {'correlation_id': - handler.request.headers['Correlation-ID'], - 'duration': handler.request.duration * 1000.0, - 'headers': handler.request.headers, - 'method': handler.request.method, - 'path': handler.request.path, - 'protocol': handler.request.protocol, - 'query_args': handler.request.query_arguments, - 'remote_ip': handler.request.remote_ip, - 'status_code': handler.status_code, - 'environment': os.environ['ENVIRONMENT']}) - sprockets.logging.tornado_log_function(handler) - access_log.info.assert_called_once_with(*expectation) - - - -class JSONRequestHandlerTestCase(unittest.TestCase): +class TornadoLoggingTestMixin(object): def setUp(self): - self.maxDiff = 32768 + super(TornadoLoggingTestMixin, self).setUp() + self.access_log = logging.getLogger('tornado.access') + self.app_log = logging.getLogger('tornado.application') + self.gen_log = logging.getLogger('tornado.general') + for logger in (self.access_log, self.app_log, self.gen_log): + logger.disabled = False - def test_log_function_return_value(self): - class LoggingHandler(logging.Handler): - def __init__(self, level): - super(LoggingHandler, self).__init__(level) - self.formatter = sprockets.logging.JSONRequestFormatter() - self.records = [] - self.results = [] - - def handle(self, value): - self.records.append(value) - self.results.append(self.formatter.format(value)) - - logging_handler = LoggingHandler(logging.INFO) - LOGGER.addHandler(logging_handler) - - handler = MockHandler() - args = {'correlation_id': - handler.request.headers['Correlation-ID'], - 'duration': handler.request.duration * 1000.0, - 'headers': handler.request.headers, - 'method': handler.request.method, - 'path': handler.request.path, - 'protocol': handler.request.protocol, - 'query_args': handler.request.query_arguments, - 'remote_ip': handler.request.remote_ip, - 'status_code': handler.status_code} - - LOGGER.info('', args) - result = logging_handler.results.pop(0) - keys = ['line_number', 'file', 'level', 'module', 'name', - 'process', 'thread', 'timestamp', 'request'] - value = json.loads(result) - for key in keys: - self.assertIn(key, value) - - -class JSONRequestFormatterTestCase(testing.AsyncHTTPTestCase): - - def setUp(self): - super(JSONRequestFormatterTestCase, self).setUp() self.recorder = RecordingHandler() - self.formatter = sprockets.logging.JSONRequestFormatter() - self.recorder.setFormatter(self.formatter) - web.app_log.addHandler(self.recorder) + root_logger = logging.getLogger() + root_logger.addHandler(self.recorder) def tearDown(self): - super(JSONRequestFormatterTestCase, self).tearDown() - web.app_log.removeHandler(self.recorder) + super(TornadoLoggingTestMixin, self).tearDown() + logging.getLogger().removeHandler(self.recorder) + + +class TornadoLogFunctionTests(TornadoLoggingTestMixin, + testing.AsyncHTTPTestCase): def get_app(self): - class JustFail(web.RequestHandler): - def get(self): - raise RuntimeError('something busted') + return web.Application( + [web.url('/', SimpleHandler)], + log_function=sprockets.logging.tornado_log_function) - return web.Application([web.url('/', JustFail)]) + @property + def access_record(self): + for record, _ in self.recorder.emitted: + if record.name == 'tornado.access': + return record - def test_that_things_happen(self): - self.fetch('/') - self.assertEqual(len(self.recorder.log_lines), 1) + def test_that_redirect_logged_as_info(self): + self.fetch('?status_code=303') + self.assertEqual(self.access_record.levelno, logging.INFO) - failure_info = json.loads(self.recorder.log_lines[0]) - self.assertEqual(failure_info['traceback']['type'], 'RuntimeError') - self.assertEqual(failure_info['traceback']['message'], - 'something busted') - self.assertEqual(len(failure_info['traceback']['stack']), 2) + def test_that_client_error_logged_as_warning(self): + self.fetch('?status_code=400') + self.assertEqual(self.access_record.levelno, logging.WARNING) + + def test_that_exception_is_logged_as_error(self): + self.fetch('/?runtime_error=something%20bad%20happened') + self.assertEqual(self.access_record.levelno, logging.ERROR) + + def test_that_log_includes_correlation_id(self): + self.fetch('/?runtime_error=something%20bad%20happened') + self.assertIn('correlation_id', self.access_record.args) + + def test_that_log_includes_duration(self): + self.fetch('/?runtime_error=something%20bad%20happened') + self.assertIn('duration', self.access_record.args) + + def test_that_log_includes_headers(self): + self.fetch('/?runtime_error=something%20bad%20happened') + self.assertIn('headers', self.access_record.args) + + def test_that_log_includes_method(self): + self.fetch('/?runtime_error=something%20bad%20happened') + self.assertEqual(self.access_record.args['method'], 'GET') + + def test_that_log_includess_path(self): + self.fetch('/?runtime_error=something%20bad%20happened') + self.assertEqual(self.access_record.args['path'], '/') + + def test_that_log_includes_protocol(self): + self.fetch('/?runtime_error=something%20bad%20happened') + self.assertEqual(self.access_record.args['protocol'], 'http') + + def test_that_log_includes_query_arguments(self): + self.fetch('/?runtime_error=something%20bad%20happened') + self.assertEqual(self.access_record.args['query_args'], + {'runtime_error': ['something bad happened']}) + + def test_that_log_includes_remote_ip(self): + self.fetch('/?runtime_error=something%20bad%20happened') + self.assertIn('remote_ip', self.access_record.args) + + def test_that_log_includes_status_code(self): + self.fetch('/?runtime_error=something%20bad%20happened') + self.assertEqual(self.access_record.args['status_code'], 500) + + def test_that_log_includes_environment(self): + self.fetch('/?runtime_error=something%20bad%20happened') + self.assertEqual(self.access_record.args['environment'], + os.environ['ENVIRONMENT']) + + def test_that_log_includes_correlation_id_from_header(self): + cid = str(uuid.uuid4()) + self.fetch('/?runtime_error=something%20bad%20happened', + headers={'Correlation-ID': cid}) + self.assertEqual(self.access_record.args['correlation_id'], cid)