mirror of
https://github.com/sprockets/sprockets-statsd.git
synced 2024-11-14 19:29:30 +00:00
Clean up type annotations.
This almost makes mypy happy. I did manage to step on a defect that there really isn't a great workaround for so I resorted to disabling typing there. https://github.com/python/mypy/issues/2427
This commit is contained in:
parent
b84e52592d
commit
2ecdee61c4
6 changed files with 102 additions and 64 deletions
15
setup.cfg
15
setup.cfg
|
@ -68,6 +68,21 @@ application_import_names = sprockets_statsd,tests
|
|||
exclude = build,env,dist
|
||||
import_order_style = pycharm
|
||||
|
||||
[mypy]
|
||||
cache_dir = build/mypy-cache
|
||||
check_untyped_defs = true
|
||||
show_error_codes = true
|
||||
warn_no_return = true
|
||||
warn_redundant_casts = true
|
||||
warn_unused_configs = true
|
||||
warn_unused_ignores = true
|
||||
|
||||
[mypy-sprockets_statsd]
|
||||
disallow_incomplete_defs = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
strict = true
|
||||
|
||||
[yapf]
|
||||
allow_split_before_dict_value = false
|
||||
indent_dictionary_value = true
|
||||
|
|
|
@ -35,11 +35,16 @@ class Connector:
|
|||
sends the metric payloads.
|
||||
|
||||
"""
|
||||
def __init__(self, host: str, port: int = 8125, **kwargs):
|
||||
self.processor = Processor(host=host, port=port, **kwargs)
|
||||
self._processor_task = None
|
||||
processor: 'Processor'
|
||||
|
||||
async def start(self):
|
||||
def __init__(self,
|
||||
host: str,
|
||||
port: int = 8125,
|
||||
**kwargs: typing.Any) -> None:
|
||||
self.processor = Processor(host=host, port=port, **kwargs)
|
||||
self._processor_task: typing.Optional[asyncio.Task[None]] = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the processor in the background.
|
||||
|
||||
This is a *blocking* method and does not return until the
|
||||
|
@ -49,7 +54,7 @@ class Connector:
|
|||
self._processor_task = asyncio.create_task(self.processor.run())
|
||||
await self.processor.running.wait()
|
||||
|
||||
async def stop(self):
|
||||
async def stop(self) -> None:
|
||||
"""Stop the background processor.
|
||||
|
||||
Items that are currently in the queue will be flushed to
|
||||
|
@ -60,7 +65,7 @@ class Connector:
|
|||
"""
|
||||
await self.processor.stop()
|
||||
|
||||
def incr(self, path: str, value: int = 1):
|
||||
def incr(self, path: str, value: int = 1) -> None:
|
||||
"""Increment a counter metric.
|
||||
|
||||
:param path: counter to increment
|
||||
|
@ -69,7 +74,7 @@ class Connector:
|
|||
"""
|
||||
self.inject_metric(path, str(value), 'c')
|
||||
|
||||
def decr(self, path: str, value: int = 1):
|
||||
def decr(self, path: str, value: int = 1) -> None:
|
||||
"""Decrement a counter metric.
|
||||
|
||||
:param path: counter to decrement
|
||||
|
@ -80,7 +85,7 @@ class Connector:
|
|||
"""
|
||||
self.inject_metric(path, str(-value), 'c')
|
||||
|
||||
def gauge(self, path: str, value: int, delta: bool = False):
|
||||
def gauge(self, path: str, value: int, delta: bool = False) -> None:
|
||||
"""Manipulate a gauge metric.
|
||||
|
||||
:param path: gauge to adjust
|
||||
|
@ -93,26 +98,25 @@ class Connector:
|
|||
|
||||
"""
|
||||
if delta:
|
||||
value = f'{value:+d}'
|
||||
payload = f'{value:+d}'
|
||||
else:
|
||||
value = str(value)
|
||||
self.inject_metric(path, value, 'g')
|
||||
payload = str(value)
|
||||
self.inject_metric(path, payload, 'g')
|
||||
|
||||
def timing(self, path: str, seconds: float):
|
||||
def timing(self, path: str, seconds: float) -> None:
|
||||
"""Send a timer metric.
|
||||
|
||||
:param path: timer to append a value to
|
||||
:param seconds: number of **seconds** to record
|
||||
|
||||
"""
|
||||
self.inject_metric(path, seconds * 1000.0, 'ms')
|
||||
self.inject_metric(path, str(seconds * 1000.0), 'ms')
|
||||
|
||||
def inject_metric(self, path: str, value, type_code: str):
|
||||
def inject_metric(self, path: str, value: str, type_code: str) -> None:
|
||||
"""Send a metric to the statsd server.
|
||||
|
||||
:param path: formatted metric name
|
||||
:param value: metric value as a number or a string. The
|
||||
string form is required for relative gauges.
|
||||
:param value: formatted metric value
|
||||
:param type_code: type of the metric to send
|
||||
|
||||
This method formats the payload and inserts it on the
|
||||
|
@ -120,10 +124,10 @@ class Connector:
|
|||
|
||||
"""
|
||||
payload = f'{path}:{value}|{type_code}'
|
||||
self.processor.queue.put_nowait(payload.encode('utf-8'))
|
||||
self.processor.enqueue(payload.encode('utf-8'))
|
||||
|
||||
|
||||
class StatsdProtocol:
|
||||
class StatsdProtocol(asyncio.BaseProtocol):
|
||||
"""Common interface for backend protocols/transports.
|
||||
|
||||
UDP and TCP transports have different interfaces (sendto vs write)
|
||||
|
@ -145,9 +149,12 @@ class StatsdProtocol:
|
|||
Is the protocol currently connected?
|
||||
|
||||
"""
|
||||
buffered_data: bytes
|
||||
ip_protocol: int = socket.IPPROTO_NONE
|
||||
logger: logging.Logger
|
||||
transport: typing.Optional[asyncio.BaseTransport]
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.buffered_data = b''
|
||||
self.connected = asyncio.Event()
|
||||
self.logger = logging.getLogger(__package__).getChild(
|
||||
|
@ -162,15 +169,16 @@ class StatsdProtocol:
|
|||
"""Shutdown the transport and wait for it to close."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def connection_made(self, transport: asyncio.Transport):
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
"""Capture the new transport and set the connected event."""
|
||||
# NB - this will return a 4-part tuple in some cases
|
||||
server, port = transport.get_extra_info('peername')[:2]
|
||||
self.logger.info('connected to statsd %s:%s', server, port)
|
||||
self.transport = transport
|
||||
self.transport.set_protocol(self)
|
||||
self.connected.set()
|
||||
|
||||
def connection_lost(self, exc: typing.Optional[Exception]):
|
||||
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
|
||||
"""Clear the connected event."""
|
||||
self.logger.warning('statsd server connection lost: %s', exc)
|
||||
self.connected.clear()
|
||||
|
@ -178,9 +186,10 @@ class StatsdProtocol:
|
|||
|
||||
class TCPProtocol(StatsdProtocol, asyncio.Protocol):
|
||||
"""StatsdProtocol implementation over a TCP/IP connection."""
|
||||
ip_protocol = socket.IPPROTO_TCP
|
||||
transport: asyncio.WriteTransport
|
||||
|
||||
def eof_received(self):
|
||||
def eof_received(self) -> None:
|
||||
self.logger.warning('received EOF from statsd server')
|
||||
self.connected.clear()
|
||||
|
||||
|
@ -219,6 +228,7 @@ class TCPProtocol(StatsdProtocol, asyncio.Protocol):
|
|||
|
||||
class UDPProtocol(StatsdProtocol, asyncio.DatagramProtocol):
|
||||
"""StatsdProtocol implementation over a UDP/IP connection."""
|
||||
ip_protocol = socket.IPPROTO_UDP
|
||||
transport: asyncio.DatagramTransport
|
||||
|
||||
def send(self, metric: bytes) -> None:
|
||||
|
@ -286,18 +296,20 @@ class Processor:
|
|||
|
||||
"""
|
||||
|
||||
protocol: typing.Union[StatsdProtocol, None]
|
||||
logger: logging.Logger
|
||||
protocol: typing.Optional[StatsdProtocol]
|
||||
queue: asyncio.Queue[bytes]
|
||||
_create_transport: typing.Callable[[], typing.Coroutine[
|
||||
typing.Any, typing.Any, typing.Tuple[asyncio.BaseTransport,
|
||||
StatsdProtocol]]]
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
host,
|
||||
host: str,
|
||||
port: int = 8125,
|
||||
reconnect_sleep: float = 1.0,
|
||||
ip_protocol: int = socket.IPPROTO_TCP,
|
||||
wait_timeout: float = 0.1):
|
||||
wait_timeout: float = 0.1) -> None:
|
||||
super().__init__()
|
||||
if not host:
|
||||
raise RuntimeError('host must be set')
|
||||
|
@ -314,9 +326,11 @@ class Processor:
|
|||
socket.IPPROTO_UDP: self._create_udp_transport,
|
||||
}
|
||||
try:
|
||||
self._create_transport = transport_creators[ip_protocol]
|
||||
factory = transport_creators[ip_protocol]
|
||||
except KeyError:
|
||||
raise RuntimeError(f'ip_protocol {ip_protocol} is not supported')
|
||||
else:
|
||||
self._create_transport = factory # type: ignore
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
@ -332,16 +346,15 @@ class Processor:
|
|||
self.protocol = None
|
||||
self.queue = asyncio.Queue()
|
||||
|
||||
self._failed_sends = []
|
||||
|
||||
@property
|
||||
def connected(self) -> bool:
|
||||
"""Is the processor connected?"""
|
||||
if self.protocol is None:
|
||||
return False
|
||||
return self.protocol.connected.is_set()
|
||||
return self.protocol is not None and self.protocol.connected.is_set()
|
||||
|
||||
async def run(self):
|
||||
def enqueue(self, metric: bytes) -> None:
|
||||
self.queue.put_nowait(metric)
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Maintains the connection and processes metric payloads."""
|
||||
self.running.set()
|
||||
self.stopped.clear()
|
||||
|
@ -374,7 +387,7 @@ class Processor:
|
|||
self.running.clear()
|
||||
self.stopped.set()
|
||||
|
||||
async def stop(self):
|
||||
async def stop(self) -> None:
|
||||
"""Stop the processor.
|
||||
|
||||
This is an asynchronous but blocking method. It does not
|
||||
|
@ -399,7 +412,7 @@ class Processor:
|
|||
reuse_port=True)
|
||||
return t, typing.cast(StatsdProtocol, p)
|
||||
|
||||
async def _connect_if_necessary(self):
|
||||
async def _connect_if_necessary(self) -> None:
|
||||
if self.protocol is not None:
|
||||
try:
|
||||
await asyncio.wait_for(self.protocol.connected.wait(),
|
||||
|
@ -412,7 +425,8 @@ class Processor:
|
|||
buffered_data = b''
|
||||
if self.protocol is not None:
|
||||
buffered_data = self.protocol.buffered_data
|
||||
transport, self.protocol = await self._create_transport()
|
||||
t, p = await self._create_transport() # type: ignore[misc]
|
||||
transport, self.protocol = t, p
|
||||
self.protocol.buffered_data = buffered_data
|
||||
self.logger.info('connection established to %s',
|
||||
transport.get_extra_info('peername'))
|
||||
|
@ -421,7 +435,7 @@ class Processor:
|
|||
self.host, self.port, error)
|
||||
await asyncio.sleep(self._reconnect_sleep)
|
||||
|
||||
async def _process_metric(self):
|
||||
async def _process_metric(self) -> None:
|
||||
try:
|
||||
metric = await asyncio.wait_for(self.queue.get(),
|
||||
self._wait_timeout)
|
||||
|
@ -432,4 +446,5 @@ class Processor:
|
|||
# it has queued metrics to send
|
||||
metric = b''
|
||||
|
||||
assert self.protocol is not None # AFAICT, this cannot happen
|
||||
self.protocol.send(metric)
|
||||
|
|
|
@ -2,6 +2,7 @@ import contextlib
|
|||
import os
|
||||
import socket
|
||||
import time
|
||||
import typing
|
||||
|
||||
from tornado import web
|
||||
|
||||
|
@ -71,7 +72,9 @@ class Application(web.Application):
|
|||
processor quickly responds to connection faults.
|
||||
|
||||
"""
|
||||
def __init__(self, *args, **settings):
|
||||
statsd_connector: typing.Optional[statsd.Connector]
|
||||
|
||||
def __init__(self, *args: typing.Any, **settings: typing.Any):
|
||||
statsd_settings = settings.setdefault('statsd', {})
|
||||
statsd_settings.setdefault('host', os.environ.get('STATSD_HOST'))
|
||||
statsd_settings.setdefault('port',
|
||||
|
@ -98,7 +101,7 @@ class Application(web.Application):
|
|||
self.settings['statsd']['port'] = int(self.settings['statsd']['port'])
|
||||
self.statsd_connector = None
|
||||
|
||||
async def start_statsd(self):
|
||||
async def start_statsd(self) -> None:
|
||||
"""Start the connector during startup.
|
||||
|
||||
Call this method during application startup to enable the statsd
|
||||
|
@ -130,7 +133,7 @@ class Application(web.Application):
|
|||
self.statsd_connector = statsd.Connector(**kwargs)
|
||||
await self.statsd_connector.start()
|
||||
|
||||
async def stop_statsd(self):
|
||||
async def stop_statsd(self) -> None:
|
||||
"""Stop the connector during shutdown.
|
||||
|
||||
If the connector was started, then this method will gracefully
|
||||
|
@ -145,20 +148,20 @@ class Application(web.Application):
|
|||
|
||||
class RequestHandler(web.RequestHandler):
|
||||
"""Mix this into your handler to send metrics to a statsd server."""
|
||||
statsd_connector: statsd.Connector
|
||||
statsd_connector: typing.Optional[statsd.Connector]
|
||||
|
||||
def initialize(self, **kwargs):
|
||||
def initialize(self, **kwargs: typing.Any) -> None:
|
||||
super().initialize(**kwargs)
|
||||
self.application: Application
|
||||
self.statsd_connector = self.application.statsd_connector
|
||||
|
||||
def __build_path(self, *path):
|
||||
def __build_path(self, *path: typing.Any) -> str:
|
||||
full_path = '.'.join(str(c) for c in path)
|
||||
if self.settings.get('statsd', {}).get('prefix', ''):
|
||||
return f'{self.settings["statsd"]["prefix"]}.{full_path}'
|
||||
return full_path
|
||||
|
||||
def record_timing(self, secs: float, *path):
|
||||
def record_timing(self, secs: float, *path: typing.Any) -> None:
|
||||
"""Record the duration.
|
||||
|
||||
:param secs: number of seconds to record
|
||||
|
@ -169,7 +172,7 @@ class RequestHandler(web.RequestHandler):
|
|||
self.statsd_connector.timing(self.__build_path('timers', *path),
|
||||
secs)
|
||||
|
||||
def increase_counter(self, *path, amount: int = 1):
|
||||
def increase_counter(self, *path: typing.Any, amount: int = 1) -> None:
|
||||
"""Adjust a counter.
|
||||
|
||||
:param path: path of the counter to adjust
|
||||
|
@ -182,7 +185,8 @@ class RequestHandler(web.RequestHandler):
|
|||
amount)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def execution_timer(self, *path):
|
||||
def execution_timer(
|
||||
self, *path: typing.Any) -> typing.Generator[None, None, None]:
|
||||
"""Record the execution duration of a block of code.
|
||||
|
||||
:param path: path to record the duration as
|
||||
|
@ -194,7 +198,7 @@ class RequestHandler(web.RequestHandler):
|
|||
finally:
|
||||
self.record_timing(time.time() - start, *path)
|
||||
|
||||
def on_finish(self):
|
||||
def on_finish(self) -> None:
|
||||
"""Extended to record the request time as a duration.
|
||||
|
||||
This method extends :meth:`tornado.web.RequestHandler.on_finish`
|
||||
|
|
|
@ -20,7 +20,7 @@ class StatsdServer(asyncio.DatagramProtocol, asyncio.Protocol):
|
|||
self.running = asyncio.Event()
|
||||
self.client_connected = asyncio.Semaphore(value=0)
|
||||
self.message_received = asyncio.Semaphore(value=0)
|
||||
self.transports: list[asyncio.Transport] = []
|
||||
self.transports: list[asyncio.BaseTransport] = []
|
||||
|
||||
self._buffer = io.BytesIO()
|
||||
|
||||
|
@ -34,7 +34,8 @@ class StatsdServer(asyncio.DatagramProtocol, asyncio.Protocol):
|
|||
self.port,
|
||||
reuse_port=True)
|
||||
self.server = server
|
||||
listening_sock = server.sockets[0]
|
||||
listening_sock = typing.cast(list[socket.socket],
|
||||
server.sockets)[0]
|
||||
self.host, self.port = listening_sock.getsockname()
|
||||
self.running.set()
|
||||
try:
|
||||
|
@ -62,7 +63,8 @@ class StatsdServer(asyncio.DatagramProtocol, asyncio.Protocol):
|
|||
self.running.clear()
|
||||
|
||||
def close(self):
|
||||
self.server.close()
|
||||
if self.server is not None:
|
||||
self.server.close()
|
||||
for connected_client in self.transports:
|
||||
connected_client.close()
|
||||
self.transports.clear()
|
||||
|
@ -74,7 +76,7 @@ class StatsdServer(asyncio.DatagramProtocol, asyncio.Protocol):
|
|||
while self.running.is_set():
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
def connection_made(self, transport: asyncio.Transport):
|
||||
def connection_made(self, transport: asyncio.BaseTransport):
|
||||
self.client_connected.release()
|
||||
self.connections_made += 1
|
||||
self.transports.append(transport)
|
||||
|
|
|
@ -2,6 +2,7 @@ import asyncio
|
|||
import logging
|
||||
import socket
|
||||
import time
|
||||
import typing
|
||||
|
||||
import asynctest
|
||||
|
||||
|
@ -110,7 +111,7 @@ class ProcessorTests(ProcessorTestCase):
|
|||
|
||||
def test_that_processor_fails_when_host_is_none(self):
|
||||
with self.assertRaises(RuntimeError) as context:
|
||||
statsd.Processor(host=None, port=12345)
|
||||
statsd.Processor(host=None, port=12345) # type: ignore[arg-type]
|
||||
self.assertIn('host', str(context.exception))
|
||||
|
||||
async def test_starting_and_stopping_without_connecting(self):
|
||||
|
@ -129,7 +130,7 @@ class ProcessorTests(ProcessorTestCase):
|
|||
await self.wait_for(processor.running.wait())
|
||||
|
||||
with self.assertLogs(processor.logger, level=logging.ERROR) as cm:
|
||||
processor.queue.put_nowait('not-bytes')
|
||||
processor.queue.put_nowait('not-bytes') # type: ignore[arg-type]
|
||||
while processor.queue.qsize() > 0:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
@ -189,15 +190,16 @@ class TCPProcessingTests(ProcessorTestCase):
|
|||
|
||||
async def test_socket_closure_while_sending(self):
|
||||
state = {'first_time': True}
|
||||
real_transport_write = self.processor.protocol.transport.write
|
||||
protocol = typing.cast(statsd.TCPProtocol, self.processor.protocol)
|
||||
real_transport_write = protocol.transport.write
|
||||
|
||||
def fake_transport_write(buffer):
|
||||
def fake_transport_write(data):
|
||||
if state['first_time']:
|
||||
self.processor.protocol.transport.close()
|
||||
state['first_time'] = False
|
||||
return real_transport_write(buffer)
|
||||
return real_transport_write(data)
|
||||
|
||||
self.processor.protocol.transport.write = fake_transport_write
|
||||
protocol.transport.write = fake_transport_write
|
||||
self.processor.queue.put_nowait(b'counter:1|c')
|
||||
await self.wait_for(self.statsd_server.message_received.acquire())
|
||||
|
||||
|
@ -265,8 +267,8 @@ class ConnectorTests(ProcessorTestCase):
|
|||
await super().asyncTearDown()
|
||||
|
||||
def assert_metrics_equal(self, recvd: bytes, path, value, type_code):
|
||||
recvd = recvd.decode('utf-8')
|
||||
recvd_path, _, rest = recvd.partition(':')
|
||||
decoded = recvd.decode('utf-8')
|
||||
recvd_path, _, rest = decoded.partition(':')
|
||||
recvd_value, _, recvd_code = rest.partition('|')
|
||||
self.assertEqual(path, recvd_path, 'metric path mismatch')
|
||||
self.assertEqual(recvd_value, str(value), 'metric value mismatch')
|
||||
|
|
|
@ -257,19 +257,19 @@ class RequestHandlerTests(AsyncTestCaseWithTimeout, testing.AsyncHTTPTestCase):
|
|||
timeout_remaining -= (time.time() - start)
|
||||
|
||||
def parse_metric(self, metric_line: bytes) -> ParsedMetric:
|
||||
metric_line = metric_line.decode()
|
||||
path, _, rest = metric_line.partition(':')
|
||||
decoded = metric_line.decode()
|
||||
path, _, rest = decoded.partition(':')
|
||||
value, _, type_code = rest.partition('|')
|
||||
try:
|
||||
value = float(value)
|
||||
parsed_value = float(value)
|
||||
except ValueError:
|
||||
self.fail(f'value of {path} is not a number: value={value!r}')
|
||||
return path, value, type_code
|
||||
return path, parsed_value, type_code
|
||||
|
||||
def find_metric(self, needle: str) -> ParsedMetric:
|
||||
needle = needle.encode()
|
||||
encoded = needle.encode()
|
||||
for line in self.statsd_server.metrics:
|
||||
if needle in line:
|
||||
if encoded in line:
|
||||
return self.parse_metric(line)
|
||||
self.fail(f'failed to find metric containing {needle!r}')
|
||||
|
||||
|
|
Loading…
Reference in a new issue