Refactor to pull out TCP details.

This will make it easier to add the UDP protocol.
This commit is contained in:
Dave Shawley 2021-03-19 18:19:03 -04:00
parent 6d310db517
commit 38a5e3f566
No known key found for this signature in database
GPG key ID: 44A9C9992CCFAB82
5 changed files with 256 additions and 113 deletions

View file

@ -24,6 +24,9 @@ documentation for a description of the supported settings.
Reference
=========
.. autoclass:: sprockets_statsd.statsd.Connector
:members:
Mixin classes
-------------
.. autoclass:: sprockets_statsd.mixins.Application
@ -34,10 +37,13 @@ Mixin classes
Internals
---------
.. autoclass:: sprockets_statsd.statsd.Connector
.. autoclass:: sprockets_statsd.statsd.Processor
:members:
.. autoclass:: sprockets_statsd.statsd.Processor
.. autoclass:: sprockets_statsd.statsd.StatsdProtocol
:members:
.. autoclass:: sprockets_statsd.statsd.TCPProtocol
:members:
Release history

View file

@ -51,6 +51,9 @@ nitpicky = 1
warning_is_error = 1
[coverage:report]
exclude_lines =
pragma: no cover
raise NotImplementedError
fail_under = 100
show_missing = 1

View file

@ -68,11 +68,104 @@ class Connector:
internal queue for future processing.
"""
payload = f'{path}:{value}|{type_code}\n'
payload = f'{path}:{value}|{type_code}'
self.processor.queue.put_nowait(payload.encode('utf-8'))
class Processor(asyncio.Protocol):
class StatsdProtocol:
    """Common interface for backend protocols/transports.

    UDP and TCP transports have different interfaces (sendto vs write)
    so this class adapts them to a common protocol that our code
    can depend on.

    .. attribute:: buffered_data
       :type: bytes

       Bytes that are buffered due to low-level transport failures.
       Since protocols & transports are created anew with each connect
       attempt, the :class:`.Processor` instance ensures that data
       buffered on a transport is copied over to the new transport
       when creating a connection.

    .. attribute:: connected
       :type: asyncio.Event

       Is the protocol currently connected?

    """
    logger: logging.Logger

    def __init__(self):
        self.buffered_data = b''
        self.connected = asyncio.Event()
        # One child logger per concrete protocol class so that log
        # records identify which transport flavour produced them.
        self.logger = logging.getLogger(__package__).getChild(
            self.__class__.__name__)
        self.transport = None

    def send(self, metric: bytes) -> None:
        """Send a metric payload over the transport.

        :param metric: encoded metric line to send
        :raises NotImplementedError: always; subclasses must override

        """
        raise NotImplementedError()

    async def shutdown(self) -> None:
        """Shutdown the transport and wait for it to close.

        :raises NotImplementedError: always; subclasses must override

        """
        raise NotImplementedError()

    def connection_made(self, transport: asyncio.Transport):
        """Capture the new transport and set the connected event.

        :param transport: the transport that was just connected

        """
        peername = transport.get_extra_info('peername')
        # The peer name is ``(host, port)`` for IPv4 but
        # ``(host, port, flowinfo, scope_id)`` for IPv6, and it may be
        # missing altogether.  Only the first two elements are used so
        # that logging can never prevent the connection from being
        # recorded.
        if isinstance(peername, (tuple, list)) and len(peername) >= 2:
            server, port = peername[:2]
            self.logger.info('connected to statsd %s:%s', server, port)
        else:
            self.logger.info('connected to statsd %s', peername)
        self.transport = transport
        self.connected.set()

    def connection_lost(self, exc: typing.Optional[Exception]):
        """Clear the connected event.

        :param exc: the error that closed the connection or :data:`None`
            when the connection was closed cleanly

        """
        self.logger.warning('statsd server connection lost: %s', exc)
        self.connected.clear()
class TCPProtocol(StatsdProtocol, asyncio.Protocol):
    """StatsdProtocol implementation over a TCP/IP connection."""
    transport: asyncio.WriteTransport

    def eof_received(self):
        """Mark the protocol as disconnected when the server sends EOF."""
        self.logger.warning('received EOF from statsd server')
        self.connected.clear()

    def send(self, metric: bytes) -> None:
        """Send `metric` to the server.

        If sending the metric fails, it will be saved in
        ``self.buffered_data``.  The processor will save and
        restore the buffered data if it needs to create a
        new protocol object.

        :param metric: encoded metric *without* a trailing newline;
            pass ``b''`` to flush previously buffered data without
            adding a new metric

        """
        if metric:
            self.buffered_data += metric + b'\n'
        elif not self.buffered_data:
            return  # nothing new and nothing buffered
        elif not self.buffered_data.endswith(b'\n'):
            # A flush must not add an empty line to the stream, it
            # only terminates a trailing line that is not terminated.
            self.buffered_data += b'\n'

        while (self.transport is not None and self.connected.is_set()
               and self.buffered_data):
            line, newline, remaining = self.buffered_data.partition(b'\n')
            self.transport.write(line + newline)
            # Writing never raises.  A low-level failure closes the
            # transport instead, so the line stays buffered for the
            # next connection.
            if self.transport.is_closing():
                self.logger.warning('transport closed during write')
                break
            self.buffered_data = remaining

    async def shutdown(self) -> None:
        """Close the transport after flushing any outstanding data."""
        self.logger.info('shutting down')
        if self.connected.is_set():
            self.send(b'')  # flush buffered data
        if self.transport is not None:
            # closing an already closed transport is harmless
            self.transport.close()
        while self.connected.is_set():
            await asyncio.sleep(0.1)
class Processor:
"""Maintains the statsd connection and sends metric payloads.
:param host: statsd server to send metrics to
@ -112,11 +205,6 @@ class Processor(asyncio.Protocol):
Formatted metric payloads to send to the statsd server. Enqueue
payloads to send them to the server.
.. attribute:: connected
:type: asyncio.Event
Is the TCP connection currently connected?
.. attribute:: running
:type: asyncio.Event
@ -132,6 +220,9 @@ class Processor(asyncio.Protocol):
until the task stops.
"""
protocol: typing.Union[StatsdProtocol, None]
def __init__(self,
*,
host,
@ -141,8 +232,13 @@ class Processor(asyncio.Protocol):
super().__init__()
if not host:
raise RuntimeError('host must be set')
if not port or port < 1:
raise RuntimeError('port must be a positive integer')
try:
port = int(port)
if not port or port < 1:
raise RuntimeError(
f'port must be a positive integer: {port!r}')
except TypeError:
raise RuntimeError(f'port must be a positive integer: {port!r}')
self.host = host
self.port = port
@ -152,14 +248,20 @@ class Processor(asyncio.Protocol):
self.running = asyncio.Event()
self.stopped = asyncio.Event()
self.stopped.set()
self.connected = asyncio.Event()
self.logger = logging.getLogger(__package__).getChild('Processor')
self.should_terminate = False
self.transport = None
self.protocol = None
self.queue = asyncio.Queue()
self._failed_sends = []
@property
def connected(self) -> bool:
    """Is the processor connected?

    A processor that has not created a protocol yet is never
    connected; otherwise the state of the protocol's ``connected``
    event is reported.

    """
    protocol = self.protocol
    return protocol is not None and protocol.connected.is_set()
async def run(self):
"""Maintains the connection and processes metric payloads."""
self.running.set()
@ -168,7 +270,7 @@ class Processor(asyncio.Protocol):
while not self.should_terminate:
try:
await self._connect_if_necessary()
if self.connected.is_set():
if self.connected:
await self._process_metric()
except asyncio.CancelledError:
self.logger.info('task cancelled, exiting')
@ -177,17 +279,14 @@ class Processor(asyncio.Protocol):
self.should_terminate = True
self.logger.info('loop finished with %d metrics in the queue',
self.queue.qsize())
if self.connected.is_set():
num_ready = self.queue.qsize()
if self.connected:
num_ready = max(self.queue.qsize(), 1)
self.logger.info('draining %d metrics', num_ready)
for _ in range(num_ready):
await self._process_metric()
self.logger.debug('closing transport')
self.transport.close()
while self.connected.is_set():
self.logger.debug('waiting on transport to close')
await asyncio.sleep(0.1)
if self.protocol is not None:
await self.protocol.shutdown()
self.logger.info('processor is exiting')
self.running.clear()
@ -204,66 +303,46 @@ class Processor(asyncio.Protocol):
self.should_terminate = True
await self.stopped.wait()
def eof_received(self):
self.logger.warning('received EOF from statsd server')
self.connected.clear()
def connection_made(self, transport: asyncio.Transport):
server, port = transport.get_extra_info('peername')
self.logger.info('connected to statsd %s:%s', server, port)
self.transport = transport
self.connected.set()
def connection_lost(self, exc: typing.Optional[Exception]):
self.logger.warning('statsd server connection lost')
self.connected.clear()
async def _connect_if_necessary(self):
try:
await asyncio.wait_for(self.connected.wait(), self._wait_timeout)
except asyncio.TimeoutError:
if self.protocol is not None:
try:
self.logger.debug('starting connection to %s:%s', self.host,
self.port)
await asyncio.get_running_loop().create_connection(
protocol_factory=lambda: self,
await asyncio.wait_for(self.protocol.connected.wait(),
self._wait_timeout)
except asyncio.TimeoutError:
self.logger.debug('protocol is no longer connected')
if not self.connected:
try:
buffered_data = b''
if self.protocol is not None:
buffered_data = self.protocol.buffered_data
loop = asyncio.get_running_loop()
transport, protocol = await loop.create_connection(
protocol_factory=TCPProtocol,
host=self.host,
port=self.port)
self.protocol = typing.cast(TCPProtocol, protocol)
self.protocol.buffered_data = buffered_data
self.logger.info('connection established to %s',
transport.get_extra_info('peername'))
except IOError as error:
self.logger.warning('connection to %s:%s failed: %s',
self.host, self.port, error)
await asyncio.sleep(self._reconnect_sleep)
async def _process_metric(self):
processing_failed_send = False
if not self._failed_sends:
try:
metric = await asyncio.wait_for(self.queue.get(),
self._wait_timeout)
self.logger.debug('received %r from queue', metric)
self.queue.task_done()
except asyncio.TimeoutError:
return
else:
# Since we `await`d the state of the transport may have
# changed. Sending on the closed transport won't return
# an error since the send is async. We can catch the
# problem here though.
if self.transport.is_closing():
self.logger.debug('preventing send on closed transport')
self._failed_sends.append(metric)
return
else:
self.logger.debug('using previous send attempt')
metric = self._failed_sends[0]
processing_failed_send = True
try:
metric = await asyncio.wait_for(self.queue.get(),
self._wait_timeout)
self.logger.debug('received %r from queue', metric)
self.queue.task_done()
except asyncio.TimeoutError:
# we still want to invoke the protocol send in case
# it has queued metrics to send
metric = b''
self.transport.write(metric)
if not self.transport.is_closing():
self.logger.debug('sent %r to statsd', metric)
if processing_failed_send:
self._failed_sends.pop(0)
else:
# Writing to a transport does not raise exceptions, it
# will close the transport if a low-level error occurs.
self.logger.debug('transport closed by writing')
try:
self.protocol.send(metric)
except Exception as error:
self.logger.exception('exception occurred when sending metric: %s',
error)

View file

@ -158,8 +158,12 @@ class RequestHandlerTests(testing.AsyncHTTPTestCase):
timeout_remaining = testing.get_async_test_timeout()
for _ in range(metric_count):
start = time.time()
self.io_loop.run_sync(self.statsd_server.message_received.acquire,
timeout=timeout_remaining)
try:
self.io_loop.run_sync(
self.statsd_server.message_received.acquire,
timeout=timeout_remaining)
except TimeoutError:
self.fail()
timeout_remaining -= (time.time() - start)
def parse_metric(self, metric_line: bytes) -> ParsedMetric:

View file

@ -1,4 +1,5 @@
import asyncio
import logging
import time
import asynctest
@ -55,7 +56,7 @@ class ProcessorTests(ProcessorTestCase):
self.statsd_server.close()
await self.statsd_server.wait_closed()
until = time.time() + self.test_timeout
while processor.connected.is_set():
while processor.connected:
await asyncio.sleep(0.1)
if time.time() >= until:
self.fail('processor never disconnected')
@ -63,7 +64,7 @@ class ProcessorTests(ProcessorTestCase):
# Start the server on the same port and let the client reconnect.
self.statsd_task = asyncio.create_task(self.statsd_server.run())
await self.wait_for(self.statsd_server.client_connected.acquire())
self.assertTrue(processor.connected.is_set())
self.assertTrue(processor.connected)
await self.wait_for(processor.stop())
@ -97,27 +98,6 @@ class ProcessorTests(ProcessorTestCase):
await self.wait_for(self.statsd_server.client_connected.acquire())
await self.wait_for(processor.stop())
async def test_connection_failures(self):
processor = statsd.Processor(host=self.statsd_server.host,
port=self.statsd_server.port)
asyncio.create_task(processor.run())
await self.wait_for(self.statsd_server.client_connected.acquire())
# Change the port and close the transport, this will cause the
# processor to reconnect to the new port and fail.
processor.port = 1
processor.transport.close()
# Wait for the processor to be disconnected, then change the
# port back and let the processor reconnect.
while processor.connected.is_set():
await asyncio.sleep(0.1)
await asyncio.sleep(0.2)
processor.port = self.statsd_server.port
await self.wait_for(self.statsd_server.client_connected.acquire())
await self.wait_for(processor.stop())
async def test_that_stopping_when_not_running_is_safe(self):
processor = statsd.Processor(host=self.statsd_server.host,
port=self.statsd_server.port)
@ -141,6 +121,92 @@ class ProcessorTests(ProcessorTestCase):
statsd.Processor(host='localhost', port=-1)
self.assertIn('port', str(context.exception))
async def test_starting_and_stopping_without_connecting(self):
    """The processor starts and stops cleanly when the server is down."""
    server = self.statsd_server
    host, port = server.host, server.port
    server.close()
    await self.wait_for(server.wait_closed())

    processor = statsd.Processor(host=host, port=port)
    asyncio.create_task(processor.run())
    await self.wait_for(processor.running.wait())
    await processor.stop()
async def test_that_protocol_exceptions_are_logged(self):
    """A payload the protocol cannot send is logged with a traceback."""
    processor = statsd.Processor(host=self.statsd_server.host,
                                 port=self.statsd_server.port)
    asyncio.create_task(processor.run())
    await self.wait_for(processor.running.wait())

    with self.assertLogs(processor.logger, level=logging.ERROR) as cm:
        # a str payload cannot be combined with the bytes buffer
        processor.queue.put_nowait('not-bytes')
        while processor.queue.qsize() > 0:
            await asyncio.sleep(0.1)

    exception_logged = any(
        entry.exc_info is not None
        and entry.funcName == '_process_metric'
        for entry in cm.records)
    if not exception_logged:
        self.fail('Expected _process_metric to log exception')

    await processor.stop()
class TCPProcessingTests(ProcessorTestCase):
    """Tests that run against a processor connected over TCP."""

    async def asyncSetUp(self):
        """Start a processor and wait for it to connect to the server."""
        await super().asyncSetUp()
        server = self.statsd_server
        self.processor = statsd.Processor(host=server.host,
                                          port=server.port)
        asyncio.create_task(self.processor.run())
        await self.wait_for(server.client_connected.acquire())

    async def asyncTearDown(self):
        """Stop the processor before the base class cleans up."""
        await self.processor.stop()
        await super().asyncTearDown()

    async def test_connection_failures(self):
        """The processor reconnects after failed connection attempts."""
        processor = self.processor

        # Change the port and close the transport, this will cause the
        # processor to reconnect to the new port and fail.
        processor.port = 1
        processor.protocol.transport.close()

        # Wait for the processor to be disconnected, then change the
        # port back and let the processor reconnect.
        while processor.connected:
            await asyncio.sleep(0.1)
        await asyncio.sleep(0.2)
        processor.port = self.statsd_server.port

        await self.wait_for(self.statsd_server.client_connected.acquire())

    async def test_socket_closure_while_processing_failed_event(self):
        """Buffered data is delivered after the socket is closed."""
        processor = self.processor
        original_process_metric = processor._process_metric
        already_closed = False

        async def fake_process_metric():
            # Only the first call buffers a metric and closes the socket.
            nonlocal already_closed
            if not already_closed:
                processor.protocol.buffered_data = b'counter:1|c\n'
                processor.protocol.transport.close()
                already_closed = True
            return await original_process_metric()

        processor._process_metric = fake_process_metric

        await self.wait_for(self.statsd_server.message_received.acquire())

    async def test_socket_closure_while_sending(self):
        """A metric written while the socket closes is still delivered."""
        processor = self.processor
        original_write = processor.protocol.transport.write
        already_closed = False

        def fake_transport_write(buffer):
            # Only the first write closes the transport beforehand.
            nonlocal already_closed
            if not already_closed:
                processor.protocol.transport.close()
                already_closed = True
            return original_write(buffer)

        processor.protocol.transport.write = fake_transport_write
        processor.queue.put_nowait(b'counter:1|c')
        await self.wait_for(self.statsd_server.message_received.acquire())
class ConnectorTests(ProcessorTestCase):
async def asyncSetUp(self):
@ -223,18 +289,3 @@ class ConnectorTests(ProcessorTestCase):
await self.wait_for(self.statsd_server.message_received.acquire())
self.assertEqual(f'counter:{value}|c'.encode(),
self.statsd_server.metrics.pop(0))
async def test_socket_closure_while_processing_failed_event(self):
state = {'first_time': True}
real_process_metric = self.connector.processor._process_metric
async def fake_process_metric():
if state['first_time']:
self.connector.processor._failed_sends.append(b'counter:1|c\n')
self.connector.processor.transport.close()
state['first_time'] = False
return await real_process_metric()
self.connector.processor._process_metric = fake_process_metric
await self.wait_for(self.statsd_server.message_received.acquire())