mirror of
https://github.com/sprockets/sprockets-statsd.git
synced 2024-11-15 03:00:25 +00:00
statsd.Processor that reconnects.
This commit is contained in:
parent
832f8af7e0
commit
9febb7e7e8
3 changed files with 249 additions and 0 deletions
71
sprockets_statsd/statsd.py
Normal file
71
sprockets_statsd/statsd.py
Normal file
|
@ -0,0 +1,71 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
|
||||
|
||||
class Processor(asyncio.Protocol):
|
||||
def __init__(self, *, host, port: int = 8125, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
self.closed = asyncio.Event()
|
||||
self.connected = asyncio.Event()
|
||||
self.logger = logging.getLogger(__package__).getChild('Processor')
|
||||
self.running = False
|
||||
self.transport = None
|
||||
|
||||
async def run(self):
|
||||
self.running = True
|
||||
while self.running:
|
||||
try:
|
||||
await self._connect_if_necessary()
|
||||
await asyncio.sleep(0.1)
|
||||
except asyncio.CancelledError:
|
||||
self.logger.info('task cancelled, exiting')
|
||||
break
|
||||
|
||||
self.running = False
|
||||
if self.connected.is_set():
|
||||
self.logger.debug('closing transport')
|
||||
self.transport.close()
|
||||
|
||||
while self.connected.is_set():
|
||||
self.logger.debug('waiting on transport to close')
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
self.logger.info('processing is exiting')
|
||||
self.closed.set()
|
||||
|
||||
async def stop(self):
|
||||
self.running = False
|
||||
await self.closed.wait()
|
||||
|
||||
def eof_received(self):
|
||||
self.logger.warning('received EOF from statsd server')
|
||||
self.connected.clear()
|
||||
|
||||
def connection_made(self, transport: asyncio.Transport):
|
||||
server, port = transport.get_extra_info('peername')
|
||||
self.logger.info('connected to statsd %s:%s', server, port)
|
||||
self.transport = transport
|
||||
self.connected.set()
|
||||
|
||||
def connection_lost(self, exc: typing.Optional[Exception]):
|
||||
self.logger.warning('statsd server connection lost')
|
||||
self.connected.clear()
|
||||
|
||||
async def _connect_if_necessary(self, wait_time: float = 0.1):
|
||||
try:
|
||||
await asyncio.wait_for(self.connected.wait(), wait_time)
|
||||
except asyncio.TimeoutError:
|
||||
try:
|
||||
self.logger.debug('starting connection to %s:%s', self.host,
|
||||
self.port)
|
||||
await asyncio.get_running_loop().create_connection(
|
||||
protocol_factory=lambda: self,
|
||||
host=self.host,
|
||||
port=self.port)
|
||||
except IOError as error:
|
||||
self.logger.warning('connection to %s:%s failed: %s',
|
||||
self.host, self.port, error)
|
66
tests/helpers.py
Normal file
66
tests/helpers.py
Normal file
|
@ -0,0 +1,66 @@
|
|||
import asyncio
|
||||
import io
|
||||
|
||||
|
||||
class StatsdServer(asyncio.Protocol):
|
||||
def __init__(self):
|
||||
self.service = None
|
||||
self.host = '127.0.0.1'
|
||||
self.port = 0
|
||||
self.connections_made = 0
|
||||
self.connections_lost = 0
|
||||
self.message_counter = 0
|
||||
|
||||
self.buffer = io.BytesIO()
|
||||
self.running = asyncio.Event()
|
||||
self.client_connected = asyncio.Semaphore(value=0)
|
||||
self.message_received = asyncio.Semaphore(value=0)
|
||||
self.transports: list[asyncio.Transport] = []
|
||||
|
||||
async def run(self):
|
||||
loop = asyncio.get_running_loop()
|
||||
self.service = await loop.create_server(lambda: self,
|
||||
self.host,
|
||||
self.port,
|
||||
reuse_port=True)
|
||||
listening_sock = self.service.sockets[0]
|
||||
self.host, self.port = listening_sock.getsockname()
|
||||
self.running.set()
|
||||
try:
|
||||
await self.service.serve_forever()
|
||||
self.running.clear()
|
||||
except asyncio.CancelledError:
|
||||
self.close()
|
||||
await self.service.wait_closed()
|
||||
except Exception as error:
|
||||
raise error
|
||||
|
||||
def close(self):
|
||||
self.running.clear()
|
||||
self.service.close()
|
||||
for connected_client in self.transports:
|
||||
connected_client.close()
|
||||
self.transports.clear()
|
||||
|
||||
async def wait_running(self):
|
||||
await self.running.wait()
|
||||
|
||||
async def wait_closed(self):
|
||||
if self.service.is_serving():
|
||||
self.close()
|
||||
await self.service.wait_closed()
|
||||
while self.running.is_set():
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
def connection_made(self, transport: asyncio.Transport):
|
||||
self.client_connected.release()
|
||||
self.connections_made += 1
|
||||
self.transports.append(transport)
|
||||
|
||||
def connection_lost(self, exc) -> None:
|
||||
self.connections_lost += 1
|
||||
|
||||
def data_received(self, data: bytes):
|
||||
self.buffer.write(data)
|
||||
self.message_received.release()
|
||||
self.message_counter += 1
|
112
tests/test_processor.py
Normal file
112
tests/test_processor.py
Normal file
|
@ -0,0 +1,112 @@
|
|||
import asyncio
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from sprockets_statsd import statsd
|
||||
|
||||
from tests import helpers
|
||||
|
||||
|
||||
class ProcessorTests(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.test_timeout = 5.0
|
||||
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.statsd_server = helpers.StatsdServer()
|
||||
self.statsd_task = asyncio.create_task(self.statsd_server.run())
|
||||
await self.statsd_server.wait_running()
|
||||
|
||||
async def asyncTearDown(self):
|
||||
self.statsd_task.cancel()
|
||||
await self.statsd_server.wait_closed()
|
||||
await super().asyncTearDown()
|
||||
|
||||
async def wait_for(self, fut):
|
||||
try:
|
||||
await asyncio.wait_for(fut, timeout=self.test_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
self.fail('future took too long to resolve')
|
||||
|
||||
async def test_that_processor_connects_and_disconnects(self):
|
||||
processor = statsd.Processor(host=self.statsd_server.host,
|
||||
port=self.statsd_server.port)
|
||||
asyncio.create_task(processor.run())
|
||||
await self.wait_for(self.statsd_server.client_connected.acquire())
|
||||
await self.wait_for(processor.stop())
|
||||
|
||||
self.assertEqual(1, self.statsd_server.connections_made)
|
||||
self.assertEqual(1, self.statsd_server.connections_lost)
|
||||
|
||||
async def test_that_processor_reconnects(self):
|
||||
processor = statsd.Processor(host=self.statsd_server.host,
|
||||
port=self.statsd_server.port)
|
||||
asyncio.create_task(processor.run())
|
||||
await self.wait_for(self.statsd_server.client_connected.acquire())
|
||||
|
||||
# Now that the server is running and the client has connected,
|
||||
# cancel the server and let it die off.
|
||||
self.statsd_server.close()
|
||||
await self.statsd_server.wait_closed()
|
||||
until = time.time() + self.test_timeout
|
||||
while processor.connected.is_set():
|
||||
await asyncio.sleep(0.1)
|
||||
if time.time() >= until:
|
||||
self.fail('processor never disconnected')
|
||||
|
||||
# Start the server on the same port and let the client reconnect.
|
||||
self.statsd_task = asyncio.create_task(self.statsd_server.run())
|
||||
await self.wait_for(self.statsd_server.client_connected.acquire())
|
||||
self.assertTrue(processor.connected.is_set())
|
||||
|
||||
await self.wait_for(processor.stop())
|
||||
|
||||
async def test_that_processor_can_be_cancelled(self):
|
||||
processor = statsd.Processor(host=self.statsd_server.host,
|
||||
port=self.statsd_server.port)
|
||||
task = asyncio.create_task(processor.run())
|
||||
await self.wait_for(self.statsd_server.client_connected.acquire())
|
||||
|
||||
task.cancel()
|
||||
await self.wait_for(processor.closed.wait())
|
||||
|
||||
async def test_shutdown_when_disconnected(self):
|
||||
processor = statsd.Processor(host=self.statsd_server.host,
|
||||
port=self.statsd_server.port)
|
||||
asyncio.create_task(processor.run())
|
||||
await self.wait_for(self.statsd_server.client_connected.acquire())
|
||||
|
||||
self.statsd_server.close()
|
||||
await self.statsd_server.wait_closed()
|
||||
|
||||
await self.wait_for(processor.stop())
|
||||
|
||||
async def test_socket_resets(self):
|
||||
processor = statsd.Processor(host=self.statsd_server.host,
|
||||
port=self.statsd_server.port)
|
||||
asyncio.create_task(processor.run())
|
||||
await self.wait_for(self.statsd_server.client_connected.acquire())
|
||||
|
||||
self.statsd_server.transports[0].close()
|
||||
await self.wait_for(self.statsd_server.client_connected.acquire())
|
||||
|
||||
async def test_connection_failures(self):
|
||||
processor = statsd.Processor(host=self.statsd_server.host,
|
||||
port=self.statsd_server.port)
|
||||
asyncio.create_task(processor.run())
|
||||
await self.wait_for(self.statsd_server.client_connected.acquire())
|
||||
|
||||
# Change the port and close the transport, this will cause the
|
||||
# processor to reconnect to the new port and fail.
|
||||
processor.port = 1
|
||||
processor.transport.close()
|
||||
|
||||
# Wait for the processor to be disconnected, then change the
|
||||
# port back and let the processor reconnect.
|
||||
while processor.connected.is_set():
|
||||
await asyncio.sleep(0.1)
|
||||
await asyncio.sleep(0.2)
|
||||
processor.port = self.statsd_server.port
|
||||
|
||||
await self.wait_for(self.statsd_server.client_connected.acquire())
|
Loading…
Reference in a new issue