sprockets-statsd/tests/helpers.py
2021-03-30 08:09:53 -04:00

132 lines
4.3 KiB
Python

import asyncio
import io
import socket
import typing
class StatsdServer(asyncio.DatagramProtocol, asyncio.Protocol):
    """In-process statsd server used by the test suite.

    Depending on *ip_protocol* the server listens with TCP (driven through
    the :class:`asyncio.Protocol` callbacks) or UDP (driven through the
    :class:`asyncio.DatagramProtocol` callbacks) on ``127.0.0.1`` using a
    port chosen by the OS.  Every newline-terminated payload it receives is
    appended to :attr:`metrics`.

    Tests synchronise with the server through:

    * :attr:`running` -- set while :meth:`run` is serving
    * :attr:`client_connected` -- released once per ``connection_made``
    * :attr:`message_received` -- released once per recorded metric

    """

    metrics: typing.List[bytes]

    def __init__(self, ip_protocol):
        # listening server (TCP) or datagram transport (UDP); set by run()
        self.server = None
        self.host = '127.0.0.1'
        # 0 lets the OS pick a free port; run() records the bound port
        self.port = 0
        self.ip_protocol = ip_protocol
        self.connections_made = 0
        self.connections_lost = 0
        self.message_counter = 0
        self.metrics = []
        self.running = asyncio.Event()
        self.client_connected = asyncio.Semaphore(value=0)
        self.message_received = asyncio.Semaphore(value=0)
        self.transports: typing.List[asyncio.BaseTransport] = []
        # bytes received after the last complete (newline-terminated) line
        self._buffer = io.BytesIO()

    async def run(self):
        """Reset state, start listening, and serve until stopped.

        :attr:`host` and :attr:`port` are updated with the address that
        was actually bound, then :attr:`running` is set.  It is cleared
        again when serving stops.

        TCP: cancelling the task closes the server and all client
        transports; the cancellation is absorbed, not re-raised.

        UDP: serving ends once the transport is closing (e.g. after
        :meth:`close`).  If the task is cancelled instead, the transport is
        closed and the cancellation propagates.

        Any other *ip_protocol* value only resets state and returns.
        """
        await self._reset()
        loop = asyncio.get_running_loop()
        if self.ip_protocol == socket.IPPROTO_TCP:
            await self._run_tcp(loop)
        elif self.ip_protocol == socket.IPPROTO_UDP:
            await self._run_udp(loop)

    async def _run_tcp(self, loop):
        """Accept TCP connections until the running task is cancelled."""
        server = await loop.create_server(lambda: self,
                                          self.host,
                                          self.port,
                                          reuse_port=True)
        self.server = server
        listening_sock = typing.cast(typing.List[socket.socket],
                                     server.sockets)[0]
        self.host, self.port = listening_sock.getsockname()
        self.running.set()
        try:
            await server.serve_forever()
        except asyncio.CancelledError:
            # Cancellation is the normal way to stop the TCP server, so shut
            # down cleanly and let the task finish instead of propagating.
            self.close()
            await server.wait_closed()
        finally:
            self.running.clear()

    async def _run_udp(self, loop):
        """Receive datagrams until the transport begins closing."""
        transport, _ = await loop.create_datagram_endpoint(
            lambda: self,
            local_addr=(self.host, self.port),
            reuse_port=True)
        self.server = transport
        self.host, self.port = transport.get_extra_info('sockname')
        self.running.set()
        try:
            # a datagram endpoint has no serve_forever(); poll for closure
            while not transport.is_closing():
                await asyncio.sleep(0.1)
        finally:
            # no-op after close(); releases the socket if we were cancelled
            transport.close()
            self.running.clear()

    def close(self):
        """Stop listening and close every transport seen so far."""
        if self.server is not None:
            self.server.close()
        for connected_client in self.transports:
            connected_client.close()
        self.transports.clear()

    async def wait_running(self):
        """Block until :meth:`run` has bound its socket."""
        await self.running.wait()

    async def wait_closed(self):
        """Poll every 0.1 s until :attr:`running` has been cleared."""
        while self.running.is_set():
            await asyncio.sleep(0.1)

    def connection_made(self, transport: asyncio.BaseTransport):
        """Record a new TCP client or the UDP endpoint's transport."""
        self.client_connected.release()
        self.connections_made += 1
        self.transports.append(transport)

    def connection_lost(self, exc) -> None:
        """Count the closed connection; *exc* is ignored."""
        self.connections_lost += 1

    def data_received(self, data: bytes):
        """Buffer a TCP chunk and record any lines it completes."""
        self._buffer.write(data)
        self._process_buffer()

    def datagram_received(self, data: bytes, _addr):
        """Record a UDP datagram; each datagram is terminated as a line."""
        self._buffer.write(data + b'\n')
        self._process_buffer()

    def _process_buffer(self):
        """Move every complete line from the buffer into :attr:`metrics`.

        Bytes after the last newline remain buffered so the next chunk can
        complete them.
        """
        buf = self._buffer.getvalue()
        complete, newline, remainder = buf.rpartition(b'\n')
        if not newline:
            return  # no complete line yet; keep accumulating
        # Write (rather than pass initial bytes) so the stream position ends
        # up after the remainder and later writes append instead of
        # overwriting it.
        self._buffer = io.BytesIO()
        self._buffer.write(remainder)
        for metric in complete.split(b'\n'):
            self.metrics.append(metric)
            self.message_received.release()
            self.message_counter += 1

    async def _reset(self):
        """Clear recorded state and close transports left from a prior run."""
        self._buffer = io.BytesIO()
        self.connections_made = 0
        self.connections_lost = 0
        self.message_counter = 0
        self.metrics.clear()
        for transport in self.transports:
            transport.close()
        self.transports.clear()
        self.running.clear()
        await self._drain_semaphore(self.client_connected)
        await self._drain_semaphore(self.message_received)

    @staticmethod
    async def _drain_semaphore(semaphore: asyncio.Semaphore):
        """Acquire *semaphore* until it is locked, giving up after a 0.1 s
        wait so a blocking acquire cannot hang the reset."""
        while not semaphore.locked():
            try:
                await asyncio.wait_for(semaphore.acquire(), 0.1)
            except asyncio.TimeoutError:
                break