statsd.Processor that reconnects.

This commit is contained in:
Dave Shawley 2021-03-06 09:50:29 -05:00
parent 832f8af7e0
commit 9febb7e7e8
No known key found for this signature in database
GPG key ID: 44A9C9992CCFAB82
3 changed files with 249 additions and 0 deletions

View file

@ -0,0 +1,71 @@
import asyncio
import logging
import typing
class Processor(asyncio.Protocol):
def __init__(self, *, host, port: int = 8125, **kwargs):
super().__init__(**kwargs)
self.host = host
self.port = port
self.closed = asyncio.Event()
self.connected = asyncio.Event()
self.logger = logging.getLogger(__package__).getChild('Processor')
self.running = False
self.transport = None
async def run(self):
self.running = True
while self.running:
try:
await self._connect_if_necessary()
await asyncio.sleep(0.1)
except asyncio.CancelledError:
self.logger.info('task cancelled, exiting')
break
self.running = False
if self.connected.is_set():
self.logger.debug('closing transport')
self.transport.close()
while self.connected.is_set():
self.logger.debug('waiting on transport to close')
await asyncio.sleep(0.1)
self.logger.info('processing is exiting')
self.closed.set()
async def stop(self):
self.running = False
await self.closed.wait()
def eof_received(self):
self.logger.warning('received EOF from statsd server')
self.connected.clear()
def connection_made(self, transport: asyncio.Transport):
server, port = transport.get_extra_info('peername')
self.logger.info('connected to statsd %s:%s', server, port)
self.transport = transport
self.connected.set()
def connection_lost(self, exc: typing.Optional[Exception]):
self.logger.warning('statsd server connection lost')
self.connected.clear()
async def _connect_if_necessary(self, wait_time: float = 0.1):
try:
await asyncio.wait_for(self.connected.wait(), wait_time)
except asyncio.TimeoutError:
try:
self.logger.debug('starting connection to %s:%s', self.host,
self.port)
await asyncio.get_running_loop().create_connection(
protocol_factory=lambda: self,
host=self.host,
port=self.port)
except IOError as error:
self.logger.warning('connection to %s:%s failed: %s',
self.host, self.port, error)

66
tests/helpers.py Normal file
View file

@ -0,0 +1,66 @@
import asyncio
import io
class StatsdServer(asyncio.Protocol):
def __init__(self):
self.service = None
self.host = '127.0.0.1'
self.port = 0
self.connections_made = 0
self.connections_lost = 0
self.message_counter = 0
self.buffer = io.BytesIO()
self.running = asyncio.Event()
self.client_connected = asyncio.Semaphore(value=0)
self.message_received = asyncio.Semaphore(value=0)
self.transports: list[asyncio.Transport] = []
async def run(self):
loop = asyncio.get_running_loop()
self.service = await loop.create_server(lambda: self,
self.host,
self.port,
reuse_port=True)
listening_sock = self.service.sockets[0]
self.host, self.port = listening_sock.getsockname()
self.running.set()
try:
await self.service.serve_forever()
self.running.clear()
except asyncio.CancelledError:
self.close()
await self.service.wait_closed()
except Exception as error:
raise error
def close(self):
self.running.clear()
self.service.close()
for connected_client in self.transports:
connected_client.close()
self.transports.clear()
async def wait_running(self):
await self.running.wait()
async def wait_closed(self):
if self.service.is_serving():
self.close()
await self.service.wait_closed()
while self.running.is_set():
await asyncio.sleep(0.1)
def connection_made(self, transport: asyncio.Transport):
self.client_connected.release()
self.connections_made += 1
self.transports.append(transport)
def connection_lost(self, exc) -> None:
self.connections_lost += 1
def data_received(self, data: bytes):
self.buffer.write(data)
self.message_received.release()
self.message_counter += 1

112
tests/test_processor.py Normal file
View file

@ -0,0 +1,112 @@
import asyncio
import time
import unittest
from sprockets_statsd import statsd
from tests import helpers
class ProcessorTests(unittest.IsolatedAsyncioTestCase):
def setUp(self):
super().setUp()
self.test_timeout = 5.0
async def asyncSetUp(self):
await super().asyncSetUp()
self.statsd_server = helpers.StatsdServer()
self.statsd_task = asyncio.create_task(self.statsd_server.run())
await self.statsd_server.wait_running()
async def asyncTearDown(self):
self.statsd_task.cancel()
await self.statsd_server.wait_closed()
await super().asyncTearDown()
async def wait_for(self, fut):
try:
await asyncio.wait_for(fut, timeout=self.test_timeout)
except asyncio.TimeoutError:
self.fail('future took too long to resolve')
async def test_that_processor_connects_and_disconnects(self):
processor = statsd.Processor(host=self.statsd_server.host,
port=self.statsd_server.port)
asyncio.create_task(processor.run())
await self.wait_for(self.statsd_server.client_connected.acquire())
await self.wait_for(processor.stop())
self.assertEqual(1, self.statsd_server.connections_made)
self.assertEqual(1, self.statsd_server.connections_lost)
async def test_that_processor_reconnects(self):
processor = statsd.Processor(host=self.statsd_server.host,
port=self.statsd_server.port)
asyncio.create_task(processor.run())
await self.wait_for(self.statsd_server.client_connected.acquire())
# Now that the server is running and the client has connected,
# cancel the server and let it die off.
self.statsd_server.close()
await self.statsd_server.wait_closed()
until = time.time() + self.test_timeout
while processor.connected.is_set():
await asyncio.sleep(0.1)
if time.time() >= until:
self.fail('processor never disconnected')
# Start the server on the same port and let the client reconnect.
self.statsd_task = asyncio.create_task(self.statsd_server.run())
await self.wait_for(self.statsd_server.client_connected.acquire())
self.assertTrue(processor.connected.is_set())
await self.wait_for(processor.stop())
async def test_that_processor_can_be_cancelled(self):
processor = statsd.Processor(host=self.statsd_server.host,
port=self.statsd_server.port)
task = asyncio.create_task(processor.run())
await self.wait_for(self.statsd_server.client_connected.acquire())
task.cancel()
await self.wait_for(processor.closed.wait())
async def test_shutdown_when_disconnected(self):
processor = statsd.Processor(host=self.statsd_server.host,
port=self.statsd_server.port)
asyncio.create_task(processor.run())
await self.wait_for(self.statsd_server.client_connected.acquire())
self.statsd_server.close()
await self.statsd_server.wait_closed()
await self.wait_for(processor.stop())
async def test_socket_resets(self):
processor = statsd.Processor(host=self.statsd_server.host,
port=self.statsd_server.port)
asyncio.create_task(processor.run())
await self.wait_for(self.statsd_server.client_connected.acquire())
self.statsd_server.transports[0].close()
await self.wait_for(self.statsd_server.client_connected.acquire())
async def test_connection_failures(self):
processor = statsd.Processor(host=self.statsd_server.host,
port=self.statsd_server.port)
asyncio.create_task(processor.run())
await self.wait_for(self.statsd_server.client_connected.acquire())
# Change the port and close the transport, this will cause the
# processor to reconnect to the new port and fail.
processor.port = 1
processor.transport.close()
# Wait for the processor to be disconnected, then change the
# port back and let the processor reconnect.
while processor.connected.is_set():
await asyncio.sleep(0.1)
await asyncio.sleep(0.2)
processor.port = self.statsd_server.port
await self.wait_for(self.statsd_server.client_connected.acquire())