mirror of
https://github.com/sprockets/sprockets-statsd.git
synced 2025-01-13 11:21:48 +00:00
Refactor to pull out TCP details.
This will make it easier to add the UDP protocol.
This commit is contained in:
parent
6d310db517
commit
38a5e3f566
5 changed files with 256 additions and 113 deletions
|
@ -24,6 +24,9 @@ documentation for a description of the supported settings.
|
|||
Reference
|
||||
=========
|
||||
|
||||
.. autoclass:: sprockets_statsd.statsd.Connector
|
||||
:members:
|
||||
|
||||
Mixin classes
|
||||
-------------
|
||||
.. autoclass:: sprockets_statsd.mixins.Application
|
||||
|
@ -34,10 +37,13 @@ Mixin classes
|
|||
|
||||
Internals
|
||||
---------
|
||||
.. autoclass:: sprockets_statsd.statsd.Connector
|
||||
.. autoclass:: sprockets_statsd.statsd.Processor
|
||||
:members:
|
||||
|
||||
.. autoclass:: sprockets_statsd.statsd.Processor
|
||||
.. autoclass:: sprockets_statsd.statsd.StatsdProtocol
|
||||
:members:
|
||||
|
||||
.. autoclass:: sprockets_statsd.statsd.TCPProtocol
|
||||
:members:
|
||||
|
||||
Release history
|
||||
|
|
|
@ -51,6 +51,9 @@ nitpicky = 1
|
|||
warning_is_error = 1
|
||||
|
||||
[coverage:report]
|
||||
exclude_lines =
|
||||
pragma: no cover
|
||||
raise NotImplementedError
|
||||
fail_under = 100
|
||||
show_missing = 1
|
||||
|
||||
|
|
|
@ -68,11 +68,104 @@ class Connector:
|
|||
internal queue for future processing.
|
||||
|
||||
"""
|
||||
payload = f'{path}:{value}|{type_code}\n'
|
||||
payload = f'{path}:{value}|{type_code}'
|
||||
self.processor.queue.put_nowait(payload.encode('utf-8'))
|
||||
|
||||
|
||||
class Processor(asyncio.Protocol):
|
||||
class StatsdProtocol:
    """Common interface for backend protocols/transports.

    UDP and TCP transports have different interfaces (sendto vs write)
    so this class adapts them to a common protocol that our code
    can depend on.

    .. attribute:: buffered_data
       :type: bytes

       Bytes that are buffered due to low-level transport failures.
       Since protocols & transports are created anew with each connect
       attempt, the :class:`.Processor` instance ensures that data
       buffered on a transport is copied over to the new transport
       when creating a connection.

    .. attribute:: connected
       :type: asyncio.Event

       Is the protocol currently connected?

    """
    logger: logging.Logger

    def __init__(self):
        self.buffered_data = b''
        self.connected = asyncio.Event()
        self.logger = logging.getLogger(__package__).getChild(
            self.__class__.__name__)
        self.transport = None

    def send(self, metric: bytes) -> None:
        """Send a metric payload over the transport."""
        raise NotImplementedError()

    async def shutdown(self) -> None:
        """Shutdown the transport and wait for it to close."""
        raise NotImplementedError()

    def connection_made(self, transport: asyncio.Transport):
        """Capture the new transport and set the connected event.

        :param transport: the transport that was just connected

        The peer address is only used for logging.  It is not assumed
        to be a ``(host, port)`` pair: IPv6 peers are reported as a
        four element tuple and the address may be unavailable
        altogether, neither of which may prevent the transport from
        being captured.

        """
        peername = transport.get_extra_info('peername')
        if isinstance(peername, (tuple, list)) and len(peername) >= 2:
            # IPv4 -> (host, port); IPv6 -> (host, port, flowinfo, scope_id)
            server, port = peername[:2]
            self.logger.info('connected to statsd %s:%s', server, port)
        else:
            self.logger.info('connected to statsd %s', peername)
        self.transport = transport
        self.connected.set()

    def connection_lost(self, exc: typing.Optional[Exception]):
        """Clear the connected event.

        :param exc: the error that closed the connection, if any; it
            is only included in the log message

        """
        self.logger.warning('statsd server connection lost: %s', exc)
        self.connected.clear()
|
||||
|
||||
|
||||
class TCPProtocol(StatsdProtocol, asyncio.Protocol):
    """StatsdProtocol implementation over a TCP/IP connection.

    Metrics are written to the transport as newline-terminated lines.
    Lines that could not be written remain in ``buffered_data`` and
    are retried by the next call to :meth:`send`.

    """
    transport: asyncio.WriteTransport

    def eof_received(self):
        """Clear the connected event when the server sends EOF."""
        self.logger.warning('received EOF from statsd server')
        self.connected.clear()

    def send(self, metric: bytes) -> None:
        """Send `metric` to the server.

        :param metric: formatted metric payload *without* a trailing
            newline; pass ``b''`` to only flush previously buffered
            data

        If sending the metric fails, it will be saved in
        ``self.buffered_data``.  The processor will save and
        restore the buffered data if it needs to create a
        new protocol object.

        """
        if metric:
            self.buffered_data += metric + b'\n'
        elif not self.buffered_data:
            return
        elif not self.buffered_data.endswith(b'\n'):
            # A flush must not emit an empty line, but whatever is
            # buffered still has to go out newline-terminated.
            self.buffered_data += b'\n'

        while (self.transport is not None and self.connected.is_set()
               and self.buffered_data):
            line, maybe_nl, rest = self.buffered_data.partition(b'\n')
            self.transport.write(line + maybe_nl)
            if self.transport.is_closing():
                # The line stays in the buffer so that it is retried
                # on a later send.
                self.logger.warning('transport closed during write')
                break
            self.buffered_data = rest

    async def shutdown(self) -> None:
        """Close the transport after flushing any outstanding data."""
        self.logger.info('shutting down')
        if self.connected.is_set():
            self.send(b'')  # flush buffered data
            self.transport.close()
            # connection_lost clears the event once the close completes
            while self.connected.is_set():
                await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
class Processor:
|
||||
"""Maintains the statsd connection and sends metric payloads.
|
||||
|
||||
:param host: statsd server to send metrics to
|
||||
|
@ -112,11 +205,6 @@ class Processor(asyncio.Protocol):
|
|||
Formatted metric payloads to send to the statsd server. Enqueue
|
||||
payloads to send them to the server.
|
||||
|
||||
.. attribute:: connected
|
||||
:type: asyncio.Event
|
||||
|
||||
Is the TCP connection currently connected?
|
||||
|
||||
.. attribute:: running
|
||||
:type: asyncio.Event
|
||||
|
||||
|
@ -132,6 +220,9 @@ class Processor(asyncio.Protocol):
|
|||
until the task stops.
|
||||
|
||||
"""
|
||||
|
||||
protocol: typing.Union[StatsdProtocol, None]
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
host,
|
||||
|
@ -141,8 +232,13 @@ class Processor(asyncio.Protocol):
|
|||
super().__init__()
|
||||
if not host:
|
||||
raise RuntimeError('host must be set')
|
||||
if not port or port < 1:
|
||||
raise RuntimeError('port must be a positive integer')
|
||||
try:
|
||||
port = int(port)
|
||||
if not port or port < 1:
|
||||
raise RuntimeError(
|
||||
f'port must be a positive integer: {port!r}')
|
||||
except TypeError:
|
||||
raise RuntimeError(f'port must be a positive integer: {port!r}')
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
@ -152,14 +248,20 @@ class Processor(asyncio.Protocol):
|
|||
self.running = asyncio.Event()
|
||||
self.stopped = asyncio.Event()
|
||||
self.stopped.set()
|
||||
self.connected = asyncio.Event()
|
||||
self.logger = logging.getLogger(__package__).getChild('Processor')
|
||||
self.should_terminate = False
|
||||
self.transport = None
|
||||
self.protocol = None
|
||||
self.queue = asyncio.Queue()
|
||||
|
||||
self._failed_sends = []
|
||||
|
||||
@property
|
||||
def connected(self) -> bool:
|
||||
"""Is the processor connected?"""
|
||||
if self.protocol is None:
|
||||
return False
|
||||
return self.protocol.connected.is_set()
|
||||
|
||||
async def run(self):
|
||||
"""Maintains the connection and processes metric payloads."""
|
||||
self.running.set()
|
||||
|
@ -168,7 +270,7 @@ class Processor(asyncio.Protocol):
|
|||
while not self.should_terminate:
|
||||
try:
|
||||
await self._connect_if_necessary()
|
||||
if self.connected.is_set():
|
||||
if self.connected:
|
||||
await self._process_metric()
|
||||
except asyncio.CancelledError:
|
||||
self.logger.info('task cancelled, exiting')
|
||||
|
@ -177,17 +279,14 @@ class Processor(asyncio.Protocol):
|
|||
self.should_terminate = True
|
||||
self.logger.info('loop finished with %d metrics in the queue',
|
||||
self.queue.qsize())
|
||||
if self.connected.is_set():
|
||||
num_ready = self.queue.qsize()
|
||||
if self.connected:
|
||||
num_ready = max(self.queue.qsize(), 1)
|
||||
self.logger.info('draining %d metrics', num_ready)
|
||||
for _ in range(num_ready):
|
||||
await self._process_metric()
|
||||
self.logger.debug('closing transport')
|
||||
self.transport.close()
|
||||
|
||||
while self.connected.is_set():
|
||||
self.logger.debug('waiting on transport to close')
|
||||
await asyncio.sleep(0.1)
|
||||
if self.protocol is not None:
|
||||
await self.protocol.shutdown()
|
||||
|
||||
self.logger.info('processor is exiting')
|
||||
self.running.clear()
|
||||
|
@ -204,66 +303,46 @@ class Processor(asyncio.Protocol):
|
|||
self.should_terminate = True
|
||||
await self.stopped.wait()
|
||||
|
||||
def eof_received(self):
|
||||
self.logger.warning('received EOF from statsd server')
|
||||
self.connected.clear()
|
||||
|
||||
def connection_made(self, transport: asyncio.Transport):
|
||||
server, port = transport.get_extra_info('peername')
|
||||
self.logger.info('connected to statsd %s:%s', server, port)
|
||||
self.transport = transport
|
||||
self.connected.set()
|
||||
|
||||
def connection_lost(self, exc: typing.Optional[Exception]):
|
||||
self.logger.warning('statsd server connection lost')
|
||||
self.connected.clear()
|
||||
|
||||
async def _connect_if_necessary(self):
|
||||
try:
|
||||
await asyncio.wait_for(self.connected.wait(), self._wait_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
if self.protocol is not None:
|
||||
try:
|
||||
self.logger.debug('starting connection to %s:%s', self.host,
|
||||
self.port)
|
||||
await asyncio.get_running_loop().create_connection(
|
||||
protocol_factory=lambda: self,
|
||||
await asyncio.wait_for(self.protocol.connected.wait(),
|
||||
self._wait_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
self.logger.debug('protocol is no longer connected')
|
||||
|
||||
if not self.connected:
|
||||
try:
|
||||
buffered_data = b''
|
||||
if self.protocol is not None:
|
||||
buffered_data = self.protocol.buffered_data
|
||||
loop = asyncio.get_running_loop()
|
||||
transport, protocol = await loop.create_connection(
|
||||
protocol_factory=TCPProtocol,
|
||||
host=self.host,
|
||||
port=self.port)
|
||||
self.protocol = typing.cast(TCPProtocol, protocol)
|
||||
self.protocol.buffered_data = buffered_data
|
||||
self.logger.info('connection established to %s',
|
||||
transport.get_extra_info('peername'))
|
||||
except IOError as error:
|
||||
self.logger.warning('connection to %s:%s failed: %s',
|
||||
self.host, self.port, error)
|
||||
await asyncio.sleep(self._reconnect_sleep)
|
||||
|
||||
async def _process_metric(self):
|
||||
processing_failed_send = False
|
||||
if not self._failed_sends:
|
||||
try:
|
||||
metric = await asyncio.wait_for(self.queue.get(),
|
||||
self._wait_timeout)
|
||||
self.logger.debug('received %r from queue', metric)
|
||||
self.queue.task_done()
|
||||
except asyncio.TimeoutError:
|
||||
return
|
||||
else:
|
||||
# Since we `await`d the state of the transport may have
|
||||
# changed. Sending on the closed transport won't return
|
||||
# an error since the send is async. We can catch the
|
||||
# problem here though.
|
||||
if self.transport.is_closing():
|
||||
self.logger.debug('preventing send on closed transport')
|
||||
self._failed_sends.append(metric)
|
||||
return
|
||||
else:
|
||||
self.logger.debug('using previous send attempt')
|
||||
metric = self._failed_sends[0]
|
||||
processing_failed_send = True
|
||||
try:
|
||||
metric = await asyncio.wait_for(self.queue.get(),
|
||||
self._wait_timeout)
|
||||
self.logger.debug('received %r from queue', metric)
|
||||
self.queue.task_done()
|
||||
except asyncio.TimeoutError:
|
||||
# we still want to invoke the protocol send in case
|
||||
# it has queued metrics to send
|
||||
metric = b''
|
||||
|
||||
self.transport.write(metric)
|
||||
if not self.transport.is_closing():
|
||||
self.logger.debug('sent %r to statsd', metric)
|
||||
if processing_failed_send:
|
||||
self._failed_sends.pop(0)
|
||||
else:
|
||||
# Writing to a transport does not raise exceptions, it
|
||||
# will close the transport if a low-level error occurs.
|
||||
self.logger.debug('transport closed by writing')
|
||||
try:
|
||||
self.protocol.send(metric)
|
||||
except Exception as error:
|
||||
self.logger.exception('exception occurred when sending metric: %s',
|
||||
error)
|
||||
|
|
|
@ -158,8 +158,12 @@ class RequestHandlerTests(testing.AsyncHTTPTestCase):
|
|||
timeout_remaining = testing.get_async_test_timeout()
|
||||
for _ in range(metric_count):
|
||||
start = time.time()
|
||||
self.io_loop.run_sync(self.statsd_server.message_received.acquire,
|
||||
timeout=timeout_remaining)
|
||||
try:
|
||||
self.io_loop.run_sync(
|
||||
self.statsd_server.message_received.acquire,
|
||||
timeout=timeout_remaining)
|
||||
except TimeoutError:
|
||||
self.fail()
|
||||
timeout_remaining -= (time.time() - start)
|
||||
|
||||
def parse_metric(self, metric_line: bytes) -> ParsedMetric:
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
|
||||
import asynctest
|
||||
|
@ -55,7 +56,7 @@ class ProcessorTests(ProcessorTestCase):
|
|||
self.statsd_server.close()
|
||||
await self.statsd_server.wait_closed()
|
||||
until = time.time() + self.test_timeout
|
||||
while processor.connected.is_set():
|
||||
while processor.connected:
|
||||
await asyncio.sleep(0.1)
|
||||
if time.time() >= until:
|
||||
self.fail('processor never disconnected')
|
||||
|
@ -63,7 +64,7 @@ class ProcessorTests(ProcessorTestCase):
|
|||
# Start the server on the same port and let the client reconnect.
|
||||
self.statsd_task = asyncio.create_task(self.statsd_server.run())
|
||||
await self.wait_for(self.statsd_server.client_connected.acquire())
|
||||
self.assertTrue(processor.connected.is_set())
|
||||
self.assertTrue(processor.connected)
|
||||
|
||||
await self.wait_for(processor.stop())
|
||||
|
||||
|
@ -97,27 +98,6 @@ class ProcessorTests(ProcessorTestCase):
|
|||
await self.wait_for(self.statsd_server.client_connected.acquire())
|
||||
await self.wait_for(processor.stop())
|
||||
|
||||
async def test_connection_failures(self):
|
||||
processor = statsd.Processor(host=self.statsd_server.host,
|
||||
port=self.statsd_server.port)
|
||||
asyncio.create_task(processor.run())
|
||||
await self.wait_for(self.statsd_server.client_connected.acquire())
|
||||
|
||||
# Change the port and close the transport, this will cause the
|
||||
# processor to reconnect to the new port and fail.
|
||||
processor.port = 1
|
||||
processor.transport.close()
|
||||
|
||||
# Wait for the processor to be disconnected, then change the
|
||||
# port back and let the processor reconnect.
|
||||
while processor.connected.is_set():
|
||||
await asyncio.sleep(0.1)
|
||||
await asyncio.sleep(0.2)
|
||||
processor.port = self.statsd_server.port
|
||||
|
||||
await self.wait_for(self.statsd_server.client_connected.acquire())
|
||||
await self.wait_for(processor.stop())
|
||||
|
||||
async def test_that_stopping_when_not_running_is_safe(self):
|
||||
processor = statsd.Processor(host=self.statsd_server.host,
|
||||
port=self.statsd_server.port)
|
||||
|
@ -141,6 +121,92 @@ class ProcessorTests(ProcessorTestCase):
|
|||
statsd.Processor(host='localhost', port=-1)
|
||||
self.assertIn('port', str(context.exception))
|
||||
|
||||
async def test_starting_and_stopping_without_connecting(self):
    # Remember where the server was listening, then shut it down so
    # that the processor has nothing to connect to.
    server = self.statsd_server
    host, port = server.host, server.port
    server.close()
    await self.wait_for(server.wait_closed())

    # The processor has to start and stop cleanly regardless.
    processor = statsd.Processor(host=host, port=port)
    asyncio.create_task(processor.run())
    await self.wait_for(processor.running.wait())
    await processor.stop()
|
||||
|
||||
async def test_that_protocol_exceptions_are_logged(self):
    server = self.statsd_server
    processor = statsd.Processor(host=server.host, port=server.port)
    asyncio.create_task(processor.run())
    await self.wait_for(processor.running.wait())

    # Enqueue a payload that is not bytes and wait for it to be consumed.
    with self.assertLogs(processor.logger, level=logging.ERROR) as cm:
        processor.queue.put_nowait('not-bytes')
        while processor.queue.qsize() > 0:
            await asyncio.sleep(0.1)

    # At least one captured record must carry exception information
    # and originate from _process_metric.
    exception_logged = any(
        record.exc_info is not None
        and record.funcName == '_process_metric'
        for record in cm.records)
    if not exception_logged:
        self.fail('Expected _process_metric to log exception')

    await processor.stop()
|
||||
|
||||
|
||||
class TCPProcessingTests(ProcessorTestCase):
    """Processor tests that start with an established TCP connection."""

    async def asyncSetUp(self):
        await super().asyncSetUp()
        server = self.statsd_server
        self.processor = statsd.Processor(host=server.host, port=server.port)
        asyncio.create_task(self.processor.run())
        # do not start a test until the server has seen the connection
        await self.wait_for(server.client_connected.acquire())

    async def asyncTearDown(self):
        await self.processor.stop()
        await super().asyncTearDown()

    async def test_connection_failures(self):
        processor = self.processor

        # Change the port and close the transport, this will cause the
        # processor to reconnect to the new port and fail.
        processor.port = 1
        processor.protocol.transport.close()

        # Wait for the processor to be disconnected, then change the
        # port back and let the processor reconnect.
        while processor.connected:
            await asyncio.sleep(0.1)
        await asyncio.sleep(0.2)
        processor.port = self.statsd_server.port

        await self.wait_for(self.statsd_server.client_connected.acquire())

    async def test_socket_closure_while_processing_failed_event(self):
        original_process_metric = self.processor._process_metric
        pending = True

        async def fake_process_metric():
            # On the first call only: leave a metric in the protocol's
            # buffer and close its transport before processing.
            nonlocal pending
            if pending:
                self.processor.protocol.buffered_data = b'counter:1|c\n'
                self.processor.protocol.transport.close()
                pending = False
            return await original_process_metric()

        self.processor._process_metric = fake_process_metric

        await self.wait_for(self.statsd_server.message_received.acquire())

    async def test_socket_closure_while_sending(self):
        original_write = self.processor.protocol.transport.write
        pending = True

        def fake_transport_write(buffer):
            # On the first call only: close the transport immediately
            # before the write goes through.
            nonlocal pending
            if pending:
                self.processor.protocol.transport.close()
                pending = False
            return original_write(buffer)

        self.processor.protocol.transport.write = fake_transport_write
        self.processor.queue.put_nowait(b'counter:1|c')
        await self.wait_for(self.statsd_server.message_received.acquire())
|
||||
|
||||
|
||||
class ConnectorTests(ProcessorTestCase):
|
||||
async def asyncSetUp(self):
|
||||
|
@ -223,18 +289,3 @@ class ConnectorTests(ProcessorTestCase):
|
|||
await self.wait_for(self.statsd_server.message_received.acquire())
|
||||
self.assertEqual(f'counter:{value}|c'.encode(),
|
||||
self.statsd_server.metrics.pop(0))
|
||||
|
||||
async def test_socket_closure_while_processing_failed_event(self):
|
||||
state = {'first_time': True}
|
||||
real_process_metric = self.connector.processor._process_metric
|
||||
|
||||
async def fake_process_metric():
|
||||
if state['first_time']:
|
||||
self.connector.processor._failed_sends.append(b'counter:1|c\n')
|
||||
self.connector.processor.transport.close()
|
||||
state['first_time'] = False
|
||||
return await real_process_metric()
|
||||
|
||||
self.connector.processor._process_metric = fake_process_metric
|
||||
|
||||
await self.wait_for(self.statsd_server.message_received.acquire())
|
||||
|
|
Loading…
Reference in a new issue