import asyncio
import io
import socket
import typing


class StatsdServer(asyncio.DatagramProtocol, asyncio.Protocol):
    """Local TCP/UDP listener that records every received line in ``metrics``.

    The instance acts as its own protocol factory result: ``run`` binds it to
    a listener on ``self.host`` (``self.port`` 0 lets the OS pick a port) and
    each newline-delimited payload it receives is appended to ``metrics``.
    ``client_connected`` is released once per ``connection_made`` and
    ``message_received`` once per recorded metric, so callers can await them.

    :param ip_protocol: ``socket.IPPROTO_TCP`` or ``socket.IPPROTO_UDP``;
        selects which kind of listener ``run`` creates.
    """

    metrics: typing.List[bytes]

    def __init__(self, ip_protocol):
        self.server = None
        self.host = '127.0.0.1'
        self.port = 0
        self.ip_protocol = ip_protocol
        self.connections_made = 0
        self.connections_lost = 0
        self.message_counter = 0
        self.metrics = []
        self.running = asyncio.Event()
        self.client_connected = asyncio.Semaphore(value=0)
        self.message_received = asyncio.Semaphore(value=0)
        self.transports: typing.List[asyncio.BaseTransport] = []
        # Holds bytes received but not yet terminated by a newline.
        self._buffer = io.BytesIO()

    async def run(self):
        """Reset state, start listening and serve until closed or cancelled.

        ``running`` is set once the listener is bound and ``host``/``port``
        hold the actual bound address; it is cleared when serving stops.
        For TCP, cancellation closes the server and returns normally.  For
        UDP, cancellation closes the transport and is re-raised.  Any other
        ``ip_protocol`` value makes this return right after the reset.
        """
        await self._reset()
        loop = asyncio.get_running_loop()

        if self.ip_protocol == socket.IPPROTO_TCP:
            server = await loop.create_server(
                lambda: self, self.host, self.port, reuse_port=True)
            self.server = server
            listening_sock = typing.cast(
                typing.List[socket.socket], server.sockets)[0]
            self.host, self.port = listening_sock.getsockname()
            self.running.set()
            try:
                await server.serve_forever()
            except asyncio.CancelledError:
                # Cancellation is the way to stop the TCP server: close the
                # listener and all client connections, then return normally.
                self.close()
                await server.wait_closed()
            finally:
                self.running.clear()

        elif self.ip_protocol == socket.IPPROTO_UDP:
            transport, _ = await loop.create_datagram_endpoint(
                lambda: self,
                local_addr=(self.host, self.port),
                reuse_port=True)
            self.server = transport
            self.host, self.port = transport.get_extra_info('sockname')
            self.running.set()
            try:
                # A datagram endpoint has no serve_forever(); poll until
                # close() shuts the transport down.
                while not transport.is_closing():
                    await asyncio.sleep(0.1)
            except asyncio.CancelledError:
                # Release the socket instead of leaking it, then propagate.
                self.close()
                raise
            finally:
                self.running.clear()

    def close(self):
        """Close the listener (if one was started) and every tracked transport."""
        if self.server is not None:
            self.server.close()
        for connected_client in self.transports:
            connected_client.close()
        self.transports.clear()

    async def wait_running(self):
        """Block until ``run`` has bound the listener."""
        await self.running.wait()

    async def wait_closed(self):
        """Poll every 0.1 seconds until ``run`` has stopped serving."""
        while self.running.is_set():
            await asyncio.sleep(0.1)

    def connection_made(self, transport: asyncio.BaseTransport):
        """Count and track a new transport and signal ``client_connected``."""
        self.client_connected.release()
        self.connections_made += 1
        self.transports.append(transport)

    def connection_lost(self, exc) -> None:
        """Count a closed connection; ``exc`` is ignored."""
        self.connections_lost += 1

    def data_received(self, data: bytes):
        """Append a TCP chunk to the buffer and record any complete lines."""
        self._buffer.write(data)
        self._process_buffer()

    def datagram_received(self, data: bytes, _addr):
        """Record a UDP datagram; its end also terminates its last line."""
        # Only add the terminator when the client did not already send one;
        # an extra newline would be recorded as an empty metric.
        if not data.endswith(b'\n'):
            data += b'\n'
        self._buffer.write(data)
        self._process_buffer()

    def _process_buffer(self):
        """Move every newline-terminated line from the buffer into ``metrics``.

        Bytes after the last newline stay buffered so a later chunk can
        complete them.  Each recorded line increments ``message_counter`` and
        releases ``message_received``.
        """
        complete, newline, remainder = self._buffer.getvalue().rpartition(
            b'\n')
        if not newline:
            return  # no complete line yet; keep accumulating

        # Write the remainder into a fresh buffer rather than passing it to
        # the constructor: io.BytesIO(initial) is positioned at offset 0, so
        # the next write() would overwrite the retained partial line.
        self._buffer = io.BytesIO()
        self._buffer.write(remainder)

        for metric in complete.split(b'\n'):
            self.metrics.append(metric)
            self.message_received.release()
            self.message_counter += 1

    async def _reset(self):
        """Clear counters, buffers and transports and drain the semaphores."""
        self._buffer = io.BytesIO()
        self.connections_made = 0
        self.connections_lost = 0
        self.message_counter = 0
        self.metrics.clear()
        for transport in self.transports:
            transport.close()
        self.transports.clear()
        self.running.clear()
        await self._drain_semaphore(self.client_connected)
        await self._drain_semaphore(self.message_received)

    @staticmethod
    async def _drain_semaphore(semaphore: asyncio.Semaphore):
        """Acquire ``semaphore`` until it is locked, discarding pending releases."""
        while not semaphore.locked():
            try:
                await asyncio.wait_for(semaphore.acquire(), 0.1)
            except asyncio.TimeoutError:
                break