import asyncio
import io
import socket
import typing


class StatsdServer(asyncio.DatagramProtocol, asyncio.Protocol):
    """Minimal statsd listener that records every metric line it receives.

    The same instance is used as the protocol for either a TCP server or a
    UDP endpoint, selected by ``ip_protocol`` (``socket.IPPROTO_TCP`` or
    ``socket.IPPROTO_UDP``).  Incoming bytes are split on newlines and each
    line is appended to ``metrics``.

    Synchronisation helpers for callers:

    * ``running`` -- set once the socket is bound and ``host``/``port``
      hold the real listening address; cleared when :meth:`run` exits.
    * ``client_connected`` -- released once per ``connection_made`` call.
    * ``message_received`` -- released once per recorded metric line.
    """

    metrics: typing.List[bytes]

    def __init__(self, ip_protocol):
        """Create an unbound server.

        :param ip_protocol: ``socket.IPPROTO_TCP`` or ``socket.IPPROTO_UDP``;
            selects the listener :meth:`run` creates.  Any other value makes
            :meth:`run` return right after resetting state.
        """
        self.server = None
        self.host = '127.0.0.1'
        # Port 0 lets the OS pick a free port; run() stores the bound one.
        self.port = 0
        self.ip_protocol = ip_protocol
        self.connections_made = 0
        self.connections_lost = 0
        self.message_counter = 0

        self.metrics = []
        self.running = asyncio.Event()
        self.client_connected = asyncio.Semaphore(value=0)
        self.message_received = asyncio.Semaphore(value=0)
        self.transports: typing.List[asyncio.BaseTransport] = []

        # Holds bytes until a complete newline-terminated line is available.
        self._buffer = io.BytesIO()

    async def run(self):
        """Reset state, bind the listener and serve until it stops.

        TCP: a ``CancelledError`` from ``serve_forever()`` closes the server
        and every client connection, waits for the server to close, and is
        not re-raised.

        UDP: returns once the transport starts closing (e.g. after
        :meth:`close`).  If the task is cancelled first, the transport is
        closed before the cancellation propagates.
        """
        await self._reset()

        loop = asyncio.get_running_loop()
        if self.ip_protocol == socket.IPPROTO_TCP:
            server = await loop.create_server(lambda: self,
                                              self.host,
                                              self.port,
                                              reuse_port=True)
            self.server = server
            listening_sock = typing.cast(typing.List[socket.socket],
                                         server.sockets)[0]
            self.host, self.port = listening_sock.getsockname()
            self.running.set()
            try:
                await server.serve_forever()
            except asyncio.CancelledError:
                # Treat cancellation as a normal shutdown request.
                self.close()
                await server.wait_closed()
            finally:
                self.running.clear()

        elif self.ip_protocol == socket.IPPROTO_UDP:
            transport, _protocol = await loop.create_datagram_endpoint(
                lambda: self,
                local_addr=(self.host, self.port),
                reuse_port=True)
            self.server = transport
            self.host, self.port = transport.get_extra_info('sockname')
            self.running.set()
            try:
                # Datagram endpoints have no serve_forever(); poll instead.
                while not transport.is_closing():
                    await asyncio.sleep(0.1)
            finally:
                # An early exit (e.g. cancellation) must not leak the socket.
                if not transport.is_closing():
                    self.close()
                self.running.clear()

    def close(self):
        """Close the listener (if any) and every tracked client transport."""
        if self.server is not None:
            self.server.close()
        for connected_client in self.transports:
            connected_client.close()
        self.transports.clear()

    async def wait_running(self):
        """Wait until :meth:`run` has bound the socket."""
        await self.running.wait()

    async def wait_closed(self):
        """Poll every 0.1 s until :meth:`run` has cleared ``running``."""
        while self.running.is_set():
            await asyncio.sleep(0.1)

    def connection_made(self, transport: asyncio.BaseTransport):
        """Track a new transport and signal ``client_connected``."""
        self.client_connected.release()
        self.connections_made += 1
        self.transports.append(transport)

    def connection_lost(self, exc) -> None:
        """Count a lost connection.

        The transport stays in ``transports`` until :meth:`close` or the
        next run's reset clears the list.
        """
        self.connections_lost += 1

    def data_received(self, data: bytes):
        """Buffer a TCP chunk, which may end in the middle of a line."""
        self._buffer.write(data)
        self._process_buffer()

    def datagram_received(self, data: bytes, _addr):
        """Buffer a UDP datagram, treating its end as a line terminator."""
        self._buffer.write(data + b'\n')
        self._process_buffer()

    def _process_buffer(self):
        """Record complete lines from the buffer, keeping any partial tail.

        Each recorded line is appended to ``metrics``, releases
        ``message_received`` once and increments ``message_counter``.
        """
        complete, newline, partial = self._buffer.getvalue().rpartition(b'\n')
        if not newline:
            return  # no complete line yet; keep buffering
        # Start a fresh buffer holding only the unterminated tail (without
        # the separating newline).  Writing it, rather than passing it to the
        # BytesIO constructor, leaves the stream position at the end, so the
        # next write appends instead of overwriting the saved bytes.
        self._buffer = io.BytesIO()
        self._buffer.write(partial)

        for metric in complete.split(b'\n'):
            self.metrics.append(metric)
            self.message_received.release()
            self.message_counter += 1

    async def _reset(self):
        """Clear buffered data, counters, metrics and connections."""
        self._buffer = io.BytesIO()
        self.connections_made = 0
        self.connections_lost = 0
        self.message_counter = 0
        self.metrics.clear()
        for transport in self.transports:
            transport.close()
        self.transports.clear()

        self.running.clear()
        # Discard releases left over from a previous run.
        await self._drain_semaphore(self.client_connected)
        await self._drain_semaphore(self.message_received)

    @staticmethod
    async def _drain_semaphore(semaphore: asyncio.Semaphore):
        """Acquire ``semaphore`` until it reports locked.

        Each acquire is bounded by a 0.1 s timeout, so a semaphore that
        cannot be acquired ends the loop instead of blocking forever.
        """
        while not semaphore.locked():
            try:
                await asyncio.wait_for(semaphore.acquire(), 0.1)
            except asyncio.TimeoutError:
                break