diff --git a/sprockets_influxdb.py b/sprockets_influxdb.py index 35ee9a5..7a3f276 100644 --- a/sprockets_influxdb.py +++ b/sprockets_influxdb.py @@ -30,6 +30,12 @@ except ImportError: # pragma: no cover logging.critical('Could not import Tornado') concurrent, httpclient, ioloop = None, None, None +try: + from tornado import routing +except ImportError: # Not needed for Tornado<4.5 + pass + + version_info = (2, 1, 0) __version__ = '.'.join(str(v) for v in version_info) __all__ = ['__version__', 'version_info', 'add_measurement', 'flush', @@ -100,15 +106,46 @@ class InfluxDBMixin(object): handler = '{}.{}'.format(self.__module__, self.__class__.__name__) self.influxdb.set_tags({'handler': handler, 'method': request.method}) - for host, handlers in application.handlers: - if not host.match(request.host): - continue + + pattern = None + if hasattr(application, 'handlers'): + pattern = self._get_path_pattern_tornado4() + else: + pattern = self._get_path_pattern_tornado45() + if pattern: + endpoint = pattern.rstrip('$') + else: + LOGGER.warning('Unable to determine routing pattern') + endpoint = request.path + self.influxdb.set_tags({'endpoint': endpoint}) + + def _get_path_pattern_tornado4(self): + """Return the path pattern used when routing a request. (Tornado<4.5) + + :rtype: str + """ + for host, handlers in self.application.handlers: + if host.match(self.request.host): for handler in handlers: - match = handler.regex.match(request.path) - if match: - self.influxdb.set_tag( - 'endpoint', handler.regex.pattern.rstrip('$')) - break + if handler.regex.match(self.request.path): + return handler.regex.pattern + + def _get_path_pattern_tornado45(self, router=None): + """Return the path pattern used when routing a request. (Tornado>=4.5) + + :param tornado.routing.Router router: (Optional) The router to scan. + Defaults to the application's router. + + :rtype: str + """ + if router is None: + router = self.application.default_router + for rule in router.rules: + if rule.matcher.match(self.request) is not None: + if isinstance(rule.matcher, routing.PathMatches): + return rule.matcher.regex.pattern + elif isinstance(rule.target, routing.Router): + return self._get_path_pattern_tornado45(rule.target) def on_finish(self): if _enabled: diff --git a/tests/mixin_tests.py b/tests/mixin_tests.py index 3a03608..360b77f 100644 --- a/tests/mixin_tests.py +++ b/tests/mixin_tests.py @@ -1,5 +1,9 @@ +import mock import socket import time +import unittest + +import tornado from . import base @@ -66,3 +70,65 @@ class MeasurementTestCase(base.AsyncServerTestCase): self.assertEqual(measurement.tags['method'], 'GET') self.assertEqual(measurement.tags['endpoint'], '/param/(?P\d+)') self.assertEqual(measurement.fields['content_length'], 13) + + def test_measurement_with_specific_host(self): + self.application.add_handlers( + 'some_host', [('/host/(?P\d+)', base.ParamRequestHandler)]) + result = self.fetch('/host/100', headers={'Host': 'some_host'}) + self.assertEqual(result.code, 200) + measurement = self.get_measurement() + self.assertIsNotNone(measurement) + self.assertEqual(measurement.db, 'database-name') + self.assertEqual(measurement.name, 'my-service') + self.assertEqual(measurement.tags['status_code'], '200') + self.assertEqual(measurement.tags['method'], 'GET') + self.assertEqual(measurement.tags['endpoint'], '/host/(?P\d+)') + self.assertEqual(measurement.fields['content_length'], 13) + + @unittest.skipIf(tornado.version_info >= (4, 5), + 'legacy routing removed in 4.5') + @mock.patch( + 'sprockets_influxdb.InfluxDBMixin._get_path_pattern_tornado45') + @mock.patch( + 'sprockets_influxdb.InfluxDBMixin._get_path_pattern_tornado4') + def test_mesurement_with_ambiguous_route_4(self, mock_4, mock_45): + mock_4.return_value = None + mock_45.return_value = None + + result = self.fetch('/param/100') + self.assertEqual(result.code, 200) + measurement = self.get_measurement() + self.assertIsNotNone(measurement) + self.assertEqual(measurement.db, 'database-name') + self.assertEqual(measurement.name, 'my-service') + self.assertEqual(measurement.tags['status_code'], '200') + self.assertEqual(measurement.tags['method'], 'GET') + self.assertEqual(measurement.tags['endpoint'], '/param/100') + self.assertEqual(measurement.fields['content_length'], 13) + + self.assertEqual(1, mock_4.call_count) + self.assertEqual(0, mock_45.call_count) + + @unittest.skipIf(tornado.version_info < (4, 5), + 'routing module introduced in tornado 4.5') + @mock.patch( + 'sprockets_influxdb.InfluxDBMixin._get_path_pattern_tornado45') + @mock.patch( + 'sprockets_influxdb.InfluxDBMixin._get_path_pattern_tornado4') + def test_mesurement_with_ambiguous_route_45(self, mock_4, mock_45): + mock_4.return_value = None + mock_45.return_value = None + + result = self.fetch('/param/100') + self.assertEqual(result.code, 200) + measurement = self.get_measurement() + self.assertIsNotNone(measurement) + self.assertEqual(measurement.db, 'database-name') + self.assertEqual(measurement.name, 'my-service') + self.assertEqual(measurement.tags['status_code'], '200') + self.assertEqual(measurement.tags['method'], 'GET') + self.assertEqual(measurement.tags['endpoint'], '/param/100') + self.assertEqual(measurement.fields['content_length'], 13) + + self.assertEqual(0, mock_4.call_count) + self.assertEqual(1, mock_45.call_count)