From 9febb7e7e840f25309afefd5822ebb58a4398644 Mon Sep 17 00:00:00 2001 From: Dave Shawley Date: Sat, 6 Mar 2021 09:50:29 -0500 Subject: [PATCH] statsd.Processor that reconnects. --- sprockets_statsd/statsd.py | 71 +++++++++++++++++++++++ tests/helpers.py | 66 ++++++++++++++++++++++ tests/test_processor.py | 112 +++++++++++++++++++++++++++++++++++++ 3 files changed, 249 insertions(+) create mode 100644 sprockets_statsd/statsd.py create mode 100644 tests/helpers.py create mode 100644 tests/test_processor.py diff --git a/sprockets_statsd/statsd.py b/sprockets_statsd/statsd.py new file mode 100644 index 0000000..2f450ca --- /dev/null +++ b/sprockets_statsd/statsd.py @@ -0,0 +1,71 @@ +import asyncio +import logging +import typing + + +class Processor(asyncio.Protocol): + def __init__(self, *, host, port: int = 8125, **kwargs): + super().__init__(**kwargs) + self.host = host + self.port = port + + self.closed = asyncio.Event() + self.connected = asyncio.Event() + self.logger = logging.getLogger(__package__).getChild('Processor') + self.running = False + self.transport = None + + async def run(self): + self.running = True + while self.running: + try: + await self._connect_if_necessary() + await asyncio.sleep(0.1) + except asyncio.CancelledError: + self.logger.info('task cancelled, exiting') + break + + self.running = False + if self.connected.is_set(): + self.logger.debug('closing transport') + self.transport.close() + + while self.connected.is_set(): + self.logger.debug('waiting on transport to close') + await asyncio.sleep(0.1) + + self.logger.info('processing is exiting') + self.closed.set() + + async def stop(self): + self.running = False + await self.closed.wait() + + def eof_received(self): + self.logger.warning('received EOF from statsd server') + self.connected.clear() + + def connection_made(self, transport: asyncio.Transport): + server, port = transport.get_extra_info('peername') + self.logger.info('connected to statsd %s:%s', server, port) + self.transport = transport + self.connected.set() + + def connection_lost(self, exc: typing.Optional[Exception]): + self.logger.warning('statsd server connection lost') + self.connected.clear() + + async def _connect_if_necessary(self, wait_time: float = 0.1): + try: + await asyncio.wait_for(self.connected.wait(), wait_time) + except asyncio.TimeoutError: + try: + self.logger.debug('starting connection to %s:%s', self.host, + self.port) + await asyncio.get_running_loop().create_connection( + protocol_factory=lambda: self, + host=self.host, + port=self.port) + except IOError as error: + self.logger.warning('connection to %s:%s failed: %s', + self.host, self.port, error) diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 0000000..fad9a39 --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,66 @@ +import asyncio +import io + + +class StatsdServer(asyncio.Protocol): + def __init__(self): + self.service = None + self.host = '127.0.0.1' + self.port = 0 + self.connections_made = 0 + self.connections_lost = 0 + self.message_counter = 0 + + self.buffer = io.BytesIO() + self.running = asyncio.Event() + self.client_connected = asyncio.Semaphore(value=0) + self.message_received = asyncio.Semaphore(value=0) + self.transports: list[asyncio.Transport] = [] + + async def run(self): + loop = asyncio.get_running_loop() + self.service = await loop.create_server(lambda: self, + self.host, + self.port, + reuse_port=True) + listening_sock = self.service.sockets[0] + self.host, self.port = listening_sock.getsockname() + self.running.set() + try: + await self.service.serve_forever() + self.running.clear() + except asyncio.CancelledError: + self.close() + await self.service.wait_closed() + except Exception as error: + raise error + + def close(self): + self.running.clear() + self.service.close() + for connected_client in self.transports: + connected_client.close() + self.transports.clear() + + async def wait_running(self): + await self.running.wait() + + async def wait_closed(self): + if self.service.is_serving(): + self.close() + await self.service.wait_closed() + while self.running.is_set(): + await asyncio.sleep(0.1) + + def connection_made(self, transport: asyncio.Transport): + self.client_connected.release() + self.connections_made += 1 + self.transports.append(transport) + + def connection_lost(self, exc) -> None: + self.connections_lost += 1 + + def data_received(self, data: bytes): + self.buffer.write(data) + self.message_received.release() + self.message_counter += 1 diff --git a/tests/test_processor.py b/tests/test_processor.py new file mode 100644 index 0000000..e8e73d9 --- /dev/null +++ b/tests/test_processor.py @@ -0,0 +1,112 @@ +import asyncio +import time +import unittest + +from sprockets_statsd import statsd + +from tests import helpers + + +class ProcessorTests(unittest.IsolatedAsyncioTestCase): + def setUp(self): + super().setUp() + self.test_timeout = 5.0 + + async def asyncSetUp(self): + await super().asyncSetUp() + self.statsd_server = helpers.StatsdServer() + self.statsd_task = asyncio.create_task(self.statsd_server.run()) + await self.statsd_server.wait_running() + + async def asyncTearDown(self): + self.statsd_task.cancel() + await self.statsd_server.wait_closed() + await super().asyncTearDown() + + async def wait_for(self, fut): + try: + await asyncio.wait_for(fut, timeout=self.test_timeout) + except asyncio.TimeoutError: + self.fail('future took too long to resolve') + + async def test_that_processor_connects_and_disconnects(self): + processor = statsd.Processor(host=self.statsd_server.host, + port=self.statsd_server.port) + asyncio.create_task(processor.run()) + await self.wait_for(self.statsd_server.client_connected.acquire()) + await self.wait_for(processor.stop()) + + self.assertEqual(1, self.statsd_server.connections_made) + self.assertEqual(1, self.statsd_server.connections_lost) + + async def test_that_processor_reconnects(self): + processor = statsd.Processor(host=self.statsd_server.host, + port=self.statsd_server.port) + asyncio.create_task(processor.run()) + await self.wait_for(self.statsd_server.client_connected.acquire()) + + # Now that the server is running and the client has connected, + # cancel the server and let it die off. + self.statsd_server.close() + await self.statsd_server.wait_closed() + until = time.time() + self.test_timeout + while processor.connected.is_set(): + await asyncio.sleep(0.1) + if time.time() >= until: + self.fail('processor never disconnected') + + # Start the server on the same port and let the client reconnect. + self.statsd_task = asyncio.create_task(self.statsd_server.run()) + await self.wait_for(self.statsd_server.client_connected.acquire()) + self.assertTrue(processor.connected.is_set()) + + await self.wait_for(processor.stop()) + + async def test_that_processor_can_be_cancelled(self): + processor = statsd.Processor(host=self.statsd_server.host, + port=self.statsd_server.port) + task = asyncio.create_task(processor.run()) + await self.wait_for(self.statsd_server.client_connected.acquire()) + + task.cancel() + await self.wait_for(processor.closed.wait()) + + async def test_shutdown_when_disconnected(self): + processor = statsd.Processor(host=self.statsd_server.host, + port=self.statsd_server.port) + asyncio.create_task(processor.run()) + await self.wait_for(self.statsd_server.client_connected.acquire()) + + self.statsd_server.close() + await self.statsd_server.wait_closed() + + await self.wait_for(processor.stop()) + + async def test_socket_resets(self): + processor = statsd.Processor(host=self.statsd_server.host, + port=self.statsd_server.port) + asyncio.create_task(processor.run()) + await self.wait_for(self.statsd_server.client_connected.acquire()) + + self.statsd_server.transports[0].close() + await self.wait_for(self.statsd_server.client_connected.acquire()) + + async def test_connection_failures(self): + processor = statsd.Processor(host=self.statsd_server.host, + port=self.statsd_server.port) + asyncio.create_task(processor.run()) + await self.wait_for(self.statsd_server.client_connected.acquire()) + + # Change the port and close the transport, this will cause the + # processor to reconnect to the new port and fail. + processor.port = 1 + processor.transport.close() + + # Wait for the processor to be disconnected, then change the + # port back and let the processor reconnect. + while processor.connected.is_set(): + await asyncio.sleep(0.1) + await asyncio.sleep(0.2) + processor.port = self.statsd_server.port + + await self.wait_for(self.statsd_server.client_connected.acquire())