Adding TCP support for statsd

This commit is contained in:
Dan g 2018-07-17 17:01:32 -04:00
parent cbaf1b657f
commit 48629e6a38
3 changed files with 234 additions and 20 deletions

View file

@ -4,6 +4,8 @@ import os
import socket
import time
from tornado import iostream
LOGGER = logging.getLogger(__name__)
SETTINGS_KEY = 'sprockets.mixins.metrics.statsd'
@ -102,6 +104,7 @@ class StatsDCollector(object):
:param str host: The StatsD host
:param str port: The StatsD port
:param bool tcp: Flag to set a TCP or UDP client
:param str namespace: The StatsD bucket to write metrics into.
:param bool prepend_metric_type: Optional flag to prepend bucket path
with the StatsD metric type
@ -110,13 +113,42 @@ class StatsDCollector(object):
METRIC_TYPES = {'c': 'counters',
'ms': 'timers'}
def __init__(self, host, port, namespace='sprockets',
def __init__(self, host, port, proto='udp', namespace='sprockets',
prepend_metric_type=True):
self._host = host
self._port = int(port)
self._address = (self._host, self._port)
self._namespace = namespace
self._prepend_metric_type = prepend_metric_type
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
if proto == 'udp':
self._tcp = False
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
else:
self._tcp = True
self._sock = self._tcp_socket()
def _tcp_socket(self):
"""Connect to statsd via TCP and return the IOStream handle.
:rtype: iostream.IOStream
"""
sock = iostream.IOStream(socket.socket(
socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP))
try:
sock.connect(self._address, self._tcp_on_connected)
except (OSError, socket.error) as error:
LOGGER.error('Failed to connect via TCP: %s', error)
sock.set_close_callback(self._tcp_on_closed)
return sock
def _tcp_on_closed(self):
"""Invoked when the socket is closed."""
LOGGER.warning('Disconnected from statsd, reconnecting')
self._sock = self._tcp_socket()
def _tcp_on_connected(self):
"""Invoked when the IOStream is connected"""
LOGGER.debug('Connected to statsd at %s via TCP', self._address)
def send(self, path, value, metric_type):
"""Send a metric to Statsd.
@ -128,10 +160,17 @@ class StatsDCollector(object):
"""
msg = '{0}:{1}|{2}'.format(
self._build_path(path, metric_type), value, metric_type)
encoding = 'utf-8'
#encoding = 'ascii'
try:
LOGGER.debug('Sending %s to %s:%s', msg.encode('ascii'),
LOGGER.debug('Sending %s to %s:%s', msg.encode(encoding),
self._host, self._port)
self._sock.sendto(msg.encode('ascii'), (self._host, self._port))
if self._tcp:
return self._sock.write(msg.encode(encoding))
else:
self._sock.sendto(msg.encode(encoding), (self._host, self._port))
except socket.error:
LOGGER.exception('Error sending StatsD metrics')
@ -191,5 +230,14 @@ def install(application, **kwargs):
if 'port' not in kwargs:
kwargs['port'] = os.environ.get('STATSD_PORT', '8125')
if 'proto' not in kwargs:
if "STATSD_PROTO" in os.environ:
kwargs['proto'] = os.environ.get('STATSD_PROTO', 'udp')
else:
kwargs['proto'] = 'udp'
if kwargs['proto'] not in ['udp', 'tcp']:
raise ValueError('Invalid value for STATSD_PROTO: {}'.format(kwargs['proto']))
setattr(application, 'statsd', StatsDCollector(**kwargs))
return True

View file

@ -2,13 +2,72 @@ import logging
import re
import socket
from tornado import gen, web
from tornado import gen, iostream, locks, tcpserver, testing, web
LOGGER = logging.getLogger(__name__)
STATS_PATTERN = re.compile(r'(?P<path>[^:]*):(?P<value>[^|]*)\|(?P<type>.*)$')
class FakeStatsdServer(object):
class FakeTCPStatsdServer(tcpserver.TCPServer):
PATTERN = br'(?P<path>[^:]*):(?P<value>[^|]*)\|(?P<type>.*)$'
def __init__(self, iol, ssl_options=None, max_buffer_size=None,
read_chunk_size=None):
self.event = locks.Event()
self.datagrams = []
self.reconnect_receive = False
super(FakeTCPStatsdServer, self).__init__(
ssl_options, max_buffer_size, read_chunk_size)
self.sock, self.port = testing.bind_unused_port()
self.add_socket(self.sock)
self.sockaddr = self.sock.getsockname()
@gen.coroutine
def handle_stream(self, stream, address):
while True:
try:
result = yield stream.read_until_regex(self.PATTERN)
except iostream.StreamClosedError:
break
else:
self.event.set()
self.datagrams.append(result)
if b'reconnect' in result:
self.reconnect_receive = True
stream.close()
return
def find_metrics(self, prefix, metric_type):
"""
Yields captured datagrams that start with `prefix`.
:param str prefix: the metric prefix to search for
:param str metric_type: the statsd metric type (e.g., 'ms', 'c')
:returns: yields (path, value, metric_type) tuples for each
captured metric that matches
:raises AssertionError: if no metrics match.
"""
pattern = re.compile(
'(?P<path>{}[^:]*):(?P<value>[^|]*)\\|(?P<type>{})'.format(
re.escape(prefix), re.escape(metric_type)))
matched = False
for datagram in self.datagrams:
text_msg = datagram.decode('ascii')
match = pattern.match(text_msg)
if match:
yield match.groups()
matched = True
if not matched:
raise AssertionError(
'Expected metric starting with "{}" in {!r}'.format(
prefix, self.datagrams))
class FakeUDPStatsdServer(object):
"""
Implements something resembling a statsd server.
@ -73,6 +132,7 @@ class FakeStatsdServer(object):
'(?P<path>{}[^:]*):(?P<value>[^|]*)\\|(?P<type>{})'.format(
re.escape(prefix), re.escape(metric_type)))
matched = False
for datagram in self.datagrams:
text_msg = datagram.decode('ascii')
match = pattern.match(text_msg)

132
tests.py
View file

@ -11,7 +11,7 @@ import mock
from sprockets.mixins.metrics import influxdb, statsd
from sprockets.mixins.metrics.testing import (
FakeInfluxHandler, FakeStatsdServer)
FakeInfluxHandler, FakeUDPStatsdServer, FakeTCPStatsdServer)
import examples.influxdb
import examples.statsd
@ -41,8 +41,7 @@ def assert_between(low, value, high):
raise AssertionError('Expected {} to be between {} and {}'.format(
value, low, high))
class StatsdMetricCollectionTests(testing.AsyncHTTPTestCase):
class TCPStatsdMetricCollectionTests(testing.AsyncHTTPTestCase):
def get_app(self):
self.application = web.Application([
@ -54,17 +53,15 @@ class StatsdMetricCollectionTests(testing.AsyncHTTPTestCase):
def setUp(self):
self.application = None
super(StatsdMetricCollectionTests, self).setUp()
self.statsd = FakeStatsdServer(self.io_loop)
super(TCPStatsdMetricCollectionTests, self).setUp()
self.statsd = FakeTCPStatsdServer(self.io_loop)
statsd.install(self.application, **{'namespace': 'testing',
'host': self.statsd.sockaddr[0],
'port': self.statsd.sockaddr[1],
'proto': 'tcp',
'prepend_metric_type': True})
def tearDown(self):
self.statsd.close()
super(StatsdMetricCollectionTests, self).tearDown()
def test_that_http_method_call_is_recorded(self):
response = self.fetch('/')
self.assertEqual(response.code, 204)
@ -111,7 +108,7 @@ class StatsdMetricCollectionTests(testing.AsyncHTTPTestCase):
list(self.statsd.find_metrics(expected, 'ms'))[0][0])
class StatsdConfigurationTests(testing.AsyncHTTPTestCase):
class TCPStatsdConfigurationTests(testing.AsyncHTTPTestCase):
def get_app(self):
self.application = web.Application([
@ -122,17 +119,126 @@ class StatsdConfigurationTests(testing.AsyncHTTPTestCase):
def setUp(self):
self.application = None
super(StatsdConfigurationTests, self).setUp()
self.statsd = FakeStatsdServer(self.io_loop)
super(TCPStatsdConfigurationTests, self).setUp()
self.statsd = FakeTCPStatsdServer(self.io_loop)
statsd.install(self.application, **{'namespace': 'testing',
'host': self.statsd.sockaddr[0],
'port': self.statsd.sockaddr[1],
'proto': 'tcp',
'prepend_metric_type': False})
def test_that_http_method_call_is_recorded(self):
response = self.fetch('/')
self.assertEqual(response.code, 204)
expected = 'testing.SimpleHandler.GET.204'
for path, value, stat_type in self.statsd.find_metrics(expected, 'ms'):
assert_between(250.0, float(value), 500.0)
def test_that_counter_accepts_increment_value(self):
response = self.fetch('/counters/path/5', method='POST', body='')
self.assertEqual(response.code, 204)
prefix = 'testing.path'
for path, value, stat_type in self.statsd.find_metrics(prefix, 'c'):
self.assertEqual(int(value), 5)
class UDPStatsdMetricCollectionTests(testing.AsyncHTTPTestCase):
def get_app(self):
self.application = web.Application([
web.url('/', examples.statsd.SimpleHandler),
web.url('/counters/(.*)/([.0-9]*)', CounterBumper),
web.url('/status_code', DefaultStatusCode),
])
return self.application
def setUp(self):
self.application = None
super(UDPStatsdMetricCollectionTests, self).setUp()
self.statsd = FakeUDPStatsdServer(self.io_loop)
statsd.install(self.application, **{'namespace': 'testing',
'host': self.statsd.sockaddr[0],
'port': self.statsd.sockaddr[1],
'proto': 'udp',
'prepend_metric_type': True})
def tearDown(self):
self.statsd.close()
super(UDPStatsdMetricCollectionTests, self).tearDown()
def test_that_http_method_call_is_recorded(self):
response = self.fetch('/')
self.assertEqual(response.code, 204)
expected = 'testing.timers.SimpleHandler.GET.204'
for path, value, stat_type in self.statsd.find_metrics(expected, 'ms'):
assert_between(250.0, float(value), 500.0)
def test_that_counter_increment_defaults_to_one(self):
response = self.fetch('/', method='POST', body='')
self.assertEqual(response.code, 204)
prefix = 'testing.counters.request.path'
for path, value, stat_type in self.statsd.find_metrics(prefix, 'c'):
self.assertEqual(int(value), 1)
def test_that_counter_accepts_increment_value(self):
response = self.fetch('/counters/path/5', method='POST', body='')
self.assertEqual(response.code, 204)
prefix = 'testing.counters.path'
for path, value, stat_type in self.statsd.find_metrics(prefix, 'c'):
self.assertEqual(int(value), 5)
def test_that_execution_timer_records_time_spent(self):
response = self.fetch('/counters/one.two.three/0.25')
self.assertEqual(response.code, 204)
prefix = 'testing.timers.one.two.three'
for path, value, stat_type in self.statsd.find_metrics(prefix, 'ms'):
assert_between(250.0, float(value), 300.0)
def test_that_add_metric_tag_is_ignored(self):
response = self.fetch('/',
headers={'Correlation-ID': 'does not matter'})
self.assertEqual(response.code, 204)
def test_that_status_code_is_used_when_not_explicitly_set(self):
response = self.fetch('/status_code')
self.assertEqual(response.code, 200)
expected = 'testing.timers.DefaultStatusCode.GET.200'
self.assertEqual(expected,
list(self.statsd.find_metrics(expected, 'ms'))[0][0])
class UDPStatsdConfigurationTests(testing.AsyncHTTPTestCase):
def get_app(self):
self.application = web.Application([
web.url('/', examples.statsd.SimpleHandler),
web.url('/counters/(.*)/([.0-9]*)', CounterBumper),
])
return self.application
def setUp(self):
self.application = None
super(UDPStatsdConfigurationTests, self).setUp()
self.statsd = FakeUDPStatsdServer(self.io_loop)
statsd.install(self.application, **{'namespace': 'testing',
'host': self.statsd.sockaddr[0],
'port': self.statsd.sockaddr[1],
'proto': 'udp',
'prepend_metric_type': False})
def tearDown(self):
self.statsd.close()
super(StatsdConfigurationTests, self).tearDown()
super(UDPStatsdConfigurationTests, self).tearDown()
def test_that_http_method_call_is_recorded(self):
response = self.fetch('/')