Rewrite tests for sprockets.logging.tornado_log_function

This exposes the defect that we found in production
This commit is contained in:
Dave Shawley 2015-09-14 07:27:47 -04:00
parent 16a57b6484
commit c1349ba8b2

238
tests.py
View file

@ -1,172 +1,128 @@
import json
import logging
import os
import random
import unittest
import uuid
import mock
import sprockets.logging
from tornado import web, testing
LOGGER = logging.getLogger(__name__)
os.environ['ENVIRONMENT'] = 'testing'
class Prototype(object):
pass
import sprockets.logging
class RecordingHandler(logging.FileHandler):
def __init__(self):
logging.FileHandler.__init__(self, filename='/dev/null')
self.log_lines = []
def format(self, record):
log_line = logging.FileHandler.format(self, record)
self.log_lines.append(log_line)
return log_line
def setup_module():
os.environ.setdefault('ENVIRONMENT', 'development')
class ContextFilterTests(unittest.TestCase):
class SimpleHandler(web.RequestHandler):
def setUp(self):
super(ContextFilterTests, self).setUp()
self.logger = logging.getLogger(uuid.uuid4().hex)
self.handler = RecordingHandler()
self.logger.addHandler(self.handler)
def test_that_filter_blocks_key_errors(self):
formatter = logging.Formatter('%(message)s [%(context)s]')
self.handler.setFormatter(formatter)
self.handler.addFilter(sprockets.logging.ContextFilter(
properties=['context']))
self.logger.info('hi there')
def test_that_filter_does_not_overwrite_extras(self):
formatter = logging.Formatter('%(message)s [%(context)s]')
self.handler.setFormatter(formatter)
self.handler.addFilter(sprockets.logging.ContextFilter(
properties=['context']))
self.logger.info('hi there', extra={'context': 'foo'})
self.assertEqual(self.handler.log_lines[-1], 'hi there [foo]')
def get(self):
if self.get_query_argument('runtime_error', default=None):
raise RuntimeError(self.get_query_argument('runtime_error'))
if self.get_query_argument('status_code', default=None) is not None:
self.set_status(int(self.get_query_argument('status_code')))
else:
self.set_status(204)
class MockRequest(object):
headers = {'Accept': 'application/msgpack',
'Correlation-ID': str(uuid.uuid4())}
method = 'GET'
path = '/test'
protocol = 'http'
remote_ip = '127.0.0.1'
query_arguments = {'mock': True}
class RecordingHandler(logging.Handler):
def __init__(self):
self.duration = random.randint(10, 200)
super(RecordingHandler, self).__init__()
self.emitted = []
def request_time(self):
return self.duration
def emit(self, record):
self.emitted.append((record, self.format(record)))
class MockHandler(object):
def __init__(self, status_code=200):
self.status_code = status_code
self.request = MockRequest()
def get_status(self):
return self.status_code
class TornadoLogFunctionTestCase(unittest.TestCase):
@mock.patch('tornado.log.access_log')
def test_log_function_return_value(self, access_log):
handler = MockHandler()
expectation = ('', {'correlation_id':
handler.request.headers['Correlation-ID'],
'duration': handler.request.duration * 1000.0,
'headers': handler.request.headers,
'method': handler.request.method,
'path': handler.request.path,
'protocol': handler.request.protocol,
'query_args': handler.request.query_arguments,
'remote_ip': handler.request.remote_ip,
'status_code': handler.status_code,
'environment': os.environ['ENVIRONMENT']})
sprockets.logging.tornado_log_function(handler)
access_log.info.assert_called_once_with(*expectation)
class JSONRequestHandlerTestCase(unittest.TestCase):
class TornadoLoggingTestMixin(object):
def setUp(self):
self.maxDiff = 32768
super(TornadoLoggingTestMixin, self).setUp()
self.access_log = logging.getLogger('tornado.access')
self.app_log = logging.getLogger('tornado.application')
self.gen_log = logging.getLogger('tornado.general')
for logger in (self.access_log, self.app_log, self.gen_log):
logger.disabled = False
def test_log_function_return_value(self):
class LoggingHandler(logging.Handler):
def __init__(self, level):
super(LoggingHandler, self).__init__(level)
self.formatter = sprockets.logging.JSONRequestFormatter()
self.records = []
self.results = []
def handle(self, value):
self.records.append(value)
self.results.append(self.formatter.format(value))
logging_handler = LoggingHandler(logging.INFO)
LOGGER.addHandler(logging_handler)
handler = MockHandler()
args = {'correlation_id':
handler.request.headers['Correlation-ID'],
'duration': handler.request.duration * 1000.0,
'headers': handler.request.headers,
'method': handler.request.method,
'path': handler.request.path,
'protocol': handler.request.protocol,
'query_args': handler.request.query_arguments,
'remote_ip': handler.request.remote_ip,
'status_code': handler.status_code}
LOGGER.info('', args)
result = logging_handler.results.pop(0)
keys = ['line_number', 'file', 'level', 'module', 'name',
'process', 'thread', 'timestamp', 'request']
value = json.loads(result)
for key in keys:
self.assertIn(key, value)
class JSONRequestFormatterTestCase(testing.AsyncHTTPTestCase):
def setUp(self):
super(JSONRequestFormatterTestCase, self).setUp()
self.recorder = RecordingHandler()
self.formatter = sprockets.logging.JSONRequestFormatter()
self.recorder.setFormatter(self.formatter)
web.app_log.addHandler(self.recorder)
root_logger = logging.getLogger()
root_logger.addHandler(self.recorder)
def tearDown(self):
super(JSONRequestFormatterTestCase, self).tearDown()
web.app_log.removeHandler(self.recorder)
super(TornadoLoggingTestMixin, self).tearDown()
logging.getLogger().removeHandler(self.recorder)
class TornadoLogFunctionTests(TornadoLoggingTestMixin,
testing.AsyncHTTPTestCase):
def get_app(self):
class JustFail(web.RequestHandler):
def get(self):
raise RuntimeError('something busted')
return web.Application(
[web.url('/', SimpleHandler)],
log_function=sprockets.logging.tornado_log_function)
return web.Application([web.url('/', JustFail)])
@property
def access_record(self):
for record, _ in self.recorder.emitted:
if record.name == 'tornado.access':
return record
def test_that_things_happen(self):
self.fetch('/')
self.assertEqual(len(self.recorder.log_lines), 1)
def test_that_redirect_logged_as_info(self):
self.fetch('?status_code=303')
self.assertEqual(self.access_record.levelno, logging.INFO)
failure_info = json.loads(self.recorder.log_lines[0])
self.assertEqual(failure_info['traceback']['type'], 'RuntimeError')
self.assertEqual(failure_info['traceback']['message'],
'something busted')
self.assertEqual(len(failure_info['traceback']['stack']), 2)
def test_that_client_error_logged_as_warning(self):
self.fetch('?status_code=400')
self.assertEqual(self.access_record.levelno, logging.WARNING)
def test_that_exception_is_logged_as_error(self):
self.fetch('/?runtime_error=something%20bad%20happened')
self.assertEqual(self.access_record.levelno, logging.ERROR)
def test_that_log_includes_correlation_id(self):
self.fetch('/?runtime_error=something%20bad%20happened')
self.assertIn('correlation_id', self.access_record.args)
def test_that_log_includes_duration(self):
self.fetch('/?runtime_error=something%20bad%20happened')
self.assertIn('duration', self.access_record.args)
def test_that_log_includes_headers(self):
self.fetch('/?runtime_error=something%20bad%20happened')
self.assertIn('headers', self.access_record.args)
def test_that_log_includes_method(self):
self.fetch('/?runtime_error=something%20bad%20happened')
self.assertEqual(self.access_record.args['method'], 'GET')
def test_that_log_includess_path(self):
self.fetch('/?runtime_error=something%20bad%20happened')
self.assertEqual(self.access_record.args['path'], '/')
def test_that_log_includes_protocol(self):
self.fetch('/?runtime_error=something%20bad%20happened')
self.assertEqual(self.access_record.args['protocol'], 'http')
def test_that_log_includes_query_arguments(self):
self.fetch('/?runtime_error=something%20bad%20happened')
self.assertEqual(self.access_record.args['query_args'],
{'runtime_error': ['something bad happened']})
def test_that_log_includes_remote_ip(self):
self.fetch('/?runtime_error=something%20bad%20happened')
self.assertIn('remote_ip', self.access_record.args)
def test_that_log_includes_status_code(self):
self.fetch('/?runtime_error=something%20bad%20happened')
self.assertEqual(self.access_record.args['status_code'], 500)
def test_that_log_includes_environment(self):
self.fetch('/?runtime_error=something%20bad%20happened')
self.assertEqual(self.access_record.args['environment'],
os.environ['ENVIRONMENT'])
def test_that_log_includes_correlation_id_from_header(self):
cid = str(uuid.uuid4())
self.fetch('/?runtime_error=something%20bad%20happened',
headers={'Correlation-ID': cid})
self.assertEqual(self.access_record.args['correlation_id'], cid)