Clean up type annotations.

This almost makes mypy happy.  I did manage to step on a defect that
there really isn't a great workaround for so I resorted to disabling
typing there.

https://github.com/python/mypy/issues/2427
This commit is contained in:
Dave Shawley 2021-03-26 06:40:27 -04:00
parent b84e52592d
commit 2ecdee61c4
No known key found for this signature in database
GPG key ID: 44A9C9992CCFAB82
6 changed files with 102 additions and 64 deletions

View file

@ -68,6 +68,21 @@ application_import_names = sprockets_statsd,tests
exclude = build,env,dist
import_order_style = pycharm
[mypy]
cache_dir = build/mypy-cache
check_untyped_defs = true
show_error_codes = true
warn_no_return = true
warn_redundant_casts = true
warn_unused_configs = true
warn_unused_ignores = true
[mypy-sprockets_statsd]
disallow_incomplete_defs = true
disallow_untyped_defs = true
no_implicit_optional = true
strict = true
[yapf]
allow_split_before_dict_value = false
indent_dictionary_value = true

View file

@ -35,11 +35,16 @@ class Connector:
sends the metric payloads.
"""
def __init__(self, host: str, port: int = 8125, **kwargs):
self.processor = Processor(host=host, port=port, **kwargs)
self._processor_task = None
processor: 'Processor'
async def start(self):
def __init__(self,
host: str,
port: int = 8125,
**kwargs: typing.Any) -> None:
self.processor = Processor(host=host, port=port, **kwargs)
self._processor_task: typing.Optional[asyncio.Task[None]] = None
async def start(self) -> None:
"""Start the processor in the background.
This is a *blocking* method and does not return until the
@ -49,7 +54,7 @@ class Connector:
self._processor_task = asyncio.create_task(self.processor.run())
await self.processor.running.wait()
async def stop(self):
async def stop(self) -> None:
"""Stop the background processor.
Items that are currently in the queue will be flushed to
@ -60,7 +65,7 @@ class Connector:
"""
await self.processor.stop()
def incr(self, path: str, value: int = 1):
def incr(self, path: str, value: int = 1) -> None:
"""Increment a counter metric.
:param path: counter to increment
@ -69,7 +74,7 @@ class Connector:
"""
self.inject_metric(path, str(value), 'c')
def decr(self, path: str, value: int = 1):
def decr(self, path: str, value: int = 1) -> None:
"""Decrement a counter metric.
:param path: counter to decrement
@ -80,7 +85,7 @@ class Connector:
"""
self.inject_metric(path, str(-value), 'c')
def gauge(self, path: str, value: int, delta: bool = False):
def gauge(self, path: str, value: int, delta: bool = False) -> None:
"""Manipulate a gauge metric.
:param path: gauge to adjust
@ -93,26 +98,25 @@ class Connector:
"""
if delta:
value = f'{value:+d}'
payload = f'{value:+d}'
else:
value = str(value)
self.inject_metric(path, value, 'g')
payload = str(value)
self.inject_metric(path, payload, 'g')
def timing(self, path: str, seconds: float):
def timing(self, path: str, seconds: float) -> None:
"""Send a timer metric.
:param path: timer to append a value to
:param seconds: number of **seconds** to record
"""
self.inject_metric(path, seconds * 1000.0, 'ms')
self.inject_metric(path, str(seconds * 1000.0), 'ms')
def inject_metric(self, path: str, value, type_code: str):
def inject_metric(self, path: str, value: str, type_code: str) -> None:
"""Send a metric to the statsd server.
:param path: formatted metric name
:param value: metric value as a number or a string. The
string form is required for relative gauges.
:param value: formatted metric value
:param type_code: type of the metric to send
This method formats the payload and inserts it on the
@ -120,10 +124,10 @@ class Connector:
"""
payload = f'{path}:{value}|{type_code}'
self.processor.queue.put_nowait(payload.encode('utf-8'))
self.processor.enqueue(payload.encode('utf-8'))
class StatsdProtocol:
class StatsdProtocol(asyncio.BaseProtocol):
"""Common interface for backend protocols/transports.
UDP and TCP transports have different interfaces (sendto vs write)
@ -145,9 +149,12 @@ class StatsdProtocol:
Is the protocol currently connected?
"""
buffered_data: bytes
ip_protocol: int = socket.IPPROTO_NONE
logger: logging.Logger
transport: typing.Optional[asyncio.BaseTransport]
def __init__(self):
def __init__(self) -> None:
self.buffered_data = b''
self.connected = asyncio.Event()
self.logger = logging.getLogger(__package__).getChild(
@ -162,15 +169,16 @@ class StatsdProtocol:
"""Shutdown the transport and wait for it to close."""
raise NotImplementedError()
def connection_made(self, transport: asyncio.Transport):
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Capture the new transport and set the connected event."""
# NB - this will return a 4-part tuple in some cases
server, port = transport.get_extra_info('peername')[:2]
self.logger.info('connected to statsd %s:%s', server, port)
self.transport = transport
self.transport.set_protocol(self)
self.connected.set()
def connection_lost(self, exc: typing.Optional[Exception]):
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
"""Clear the connected event."""
self.logger.warning('statsd server connection lost: %s', exc)
self.connected.clear()
@ -178,9 +186,10 @@ class StatsdProtocol:
class TCPProtocol(StatsdProtocol, asyncio.Protocol):
"""StatsdProtocol implementation over a TCP/IP connection."""
ip_protocol = socket.IPPROTO_TCP
transport: asyncio.WriteTransport
def eof_received(self):
def eof_received(self) -> None:
self.logger.warning('received EOF from statsd server')
self.connected.clear()
@ -219,6 +228,7 @@ class TCPProtocol(StatsdProtocol, asyncio.Protocol):
class UDPProtocol(StatsdProtocol, asyncio.DatagramProtocol):
"""StatsdProtocol implementation over a UDP/IP connection."""
ip_protocol = socket.IPPROTO_UDP
transport: asyncio.DatagramTransport
def send(self, metric: bytes) -> None:
@ -286,18 +296,20 @@ class Processor:
"""
protocol: typing.Union[StatsdProtocol, None]
logger: logging.Logger
protocol: typing.Optional[StatsdProtocol]
queue: asyncio.Queue[bytes]
_create_transport: typing.Callable[[], typing.Coroutine[
typing.Any, typing.Any, typing.Tuple[asyncio.BaseTransport,
StatsdProtocol]]]
def __init__(self,
*,
host,
host: str,
port: int = 8125,
reconnect_sleep: float = 1.0,
ip_protocol: int = socket.IPPROTO_TCP,
wait_timeout: float = 0.1):
wait_timeout: float = 0.1) -> None:
super().__init__()
if not host:
raise RuntimeError('host must be set')
@ -314,9 +326,11 @@ class Processor:
socket.IPPROTO_UDP: self._create_udp_transport,
}
try:
self._create_transport = transport_creators[ip_protocol]
factory = transport_creators[ip_protocol]
except KeyError:
raise RuntimeError(f'ip_protocol {ip_protocol} is not supported')
else:
self._create_transport = factory # type: ignore
self.host = host
self.port = port
@ -332,16 +346,15 @@ class Processor:
self.protocol = None
self.queue = asyncio.Queue()
self._failed_sends = []
@property
def connected(self) -> bool:
"""Is the processor connected?"""
if self.protocol is None:
return False
return self.protocol.connected.is_set()
return self.protocol is not None and self.protocol.connected.is_set()
async def run(self):
def enqueue(self, metric: bytes) -> None:
self.queue.put_nowait(metric)
async def run(self) -> None:
"""Maintains the connection and processes metric payloads."""
self.running.set()
self.stopped.clear()
@ -374,7 +387,7 @@ class Processor:
self.running.clear()
self.stopped.set()
async def stop(self):
async def stop(self) -> None:
"""Stop the processor.
This is an asynchronous but blocking method. It does not
@ -399,7 +412,7 @@ class Processor:
reuse_port=True)
return t, typing.cast(StatsdProtocol, p)
async def _connect_if_necessary(self):
async def _connect_if_necessary(self) -> None:
if self.protocol is not None:
try:
await asyncio.wait_for(self.protocol.connected.wait(),
@ -412,7 +425,8 @@ class Processor:
buffered_data = b''
if self.protocol is not None:
buffered_data = self.protocol.buffered_data
transport, self.protocol = await self._create_transport()
t, p = await self._create_transport() # type: ignore[misc]
transport, self.protocol = t, p
self.protocol.buffered_data = buffered_data
self.logger.info('connection established to %s',
transport.get_extra_info('peername'))
@ -421,7 +435,7 @@ class Processor:
self.host, self.port, error)
await asyncio.sleep(self._reconnect_sleep)
async def _process_metric(self):
async def _process_metric(self) -> None:
try:
metric = await asyncio.wait_for(self.queue.get(),
self._wait_timeout)
@ -432,4 +446,5 @@ class Processor:
# it has queued metrics to send
metric = b''
assert self.protocol is not None # AFAICT, this cannot happen
self.protocol.send(metric)

View file

@ -2,6 +2,7 @@ import contextlib
import os
import socket
import time
import typing
from tornado import web
@ -71,7 +72,9 @@ class Application(web.Application):
processor quickly responds to connection faults.
"""
def __init__(self, *args, **settings):
statsd_connector: typing.Optional[statsd.Connector]
def __init__(self, *args: typing.Any, **settings: typing.Any):
statsd_settings = settings.setdefault('statsd', {})
statsd_settings.setdefault('host', os.environ.get('STATSD_HOST'))
statsd_settings.setdefault('port',
@ -98,7 +101,7 @@ class Application(web.Application):
self.settings['statsd']['port'] = int(self.settings['statsd']['port'])
self.statsd_connector = None
async def start_statsd(self):
async def start_statsd(self) -> None:
"""Start the connector during startup.
Call this method during application startup to enable the statsd
@ -130,7 +133,7 @@ class Application(web.Application):
self.statsd_connector = statsd.Connector(**kwargs)
await self.statsd_connector.start()
async def stop_statsd(self):
async def stop_statsd(self) -> None:
"""Stop the connector during shutdown.
If the connector was started, then this method will gracefully
@ -145,20 +148,20 @@ class Application(web.Application):
class RequestHandler(web.RequestHandler):
"""Mix this into your handler to send metrics to a statsd server."""
statsd_connector: statsd.Connector
statsd_connector: typing.Optional[statsd.Connector]
def initialize(self, **kwargs):
def initialize(self, **kwargs: typing.Any) -> None:
super().initialize(**kwargs)
self.application: Application
self.statsd_connector = self.application.statsd_connector
def __build_path(self, *path):
def __build_path(self, *path: typing.Any) -> str:
full_path = '.'.join(str(c) for c in path)
if self.settings.get('statsd', {}).get('prefix', ''):
return f'{self.settings["statsd"]["prefix"]}.{full_path}'
return full_path
def record_timing(self, secs: float, *path):
def record_timing(self, secs: float, *path: typing.Any) -> None:
"""Record the duration.
:param secs: number of seconds to record
@ -169,7 +172,7 @@ class RequestHandler(web.RequestHandler):
self.statsd_connector.timing(self.__build_path('timers', *path),
secs)
def increase_counter(self, *path, amount: int = 1):
def increase_counter(self, *path: typing.Any, amount: int = 1) -> None:
"""Adjust a counter.
:param path: path of the counter to adjust
@ -182,7 +185,8 @@ class RequestHandler(web.RequestHandler):
amount)
@contextlib.contextmanager
def execution_timer(self, *path):
def execution_timer(
self, *path: typing.Any) -> typing.Generator[None, None, None]:
"""Record the execution duration of a block of code.
:param path: path to record the duration as
@ -194,7 +198,7 @@ class RequestHandler(web.RequestHandler):
finally:
self.record_timing(time.time() - start, *path)
def on_finish(self):
def on_finish(self) -> None:
"""Extended to record the request time as a duration.
This method extends :meth:`tornado.web.RequestHandler.on_finish`

View file

@ -20,7 +20,7 @@ class StatsdServer(asyncio.DatagramProtocol, asyncio.Protocol):
self.running = asyncio.Event()
self.client_connected = asyncio.Semaphore(value=0)
self.message_received = asyncio.Semaphore(value=0)
self.transports: list[asyncio.Transport] = []
self.transports: list[asyncio.BaseTransport] = []
self._buffer = io.BytesIO()
@ -34,7 +34,8 @@ class StatsdServer(asyncio.DatagramProtocol, asyncio.Protocol):
self.port,
reuse_port=True)
self.server = server
listening_sock = server.sockets[0]
listening_sock = typing.cast(list[socket.socket],
server.sockets)[0]
self.host, self.port = listening_sock.getsockname()
self.running.set()
try:
@ -62,7 +63,8 @@ class StatsdServer(asyncio.DatagramProtocol, asyncio.Protocol):
self.running.clear()
def close(self):
self.server.close()
if self.server is not None:
self.server.close()
for connected_client in self.transports:
connected_client.close()
self.transports.clear()
@ -74,7 +76,7 @@ class StatsdServer(asyncio.DatagramProtocol, asyncio.Protocol):
while self.running.is_set():
await asyncio.sleep(0.1)
def connection_made(self, transport: asyncio.Transport):
def connection_made(self, transport: asyncio.BaseTransport):
self.client_connected.release()
self.connections_made += 1
self.transports.append(transport)

View file

@ -2,6 +2,7 @@ import asyncio
import logging
import socket
import time
import typing
import asynctest
@ -110,7 +111,7 @@ class ProcessorTests(ProcessorTestCase):
def test_that_processor_fails_when_host_is_none(self):
with self.assertRaises(RuntimeError) as context:
statsd.Processor(host=None, port=12345)
statsd.Processor(host=None, port=12345) # type: ignore[arg-type]
self.assertIn('host', str(context.exception))
async def test_starting_and_stopping_without_connecting(self):
@ -129,7 +130,7 @@ class ProcessorTests(ProcessorTestCase):
await self.wait_for(processor.running.wait())
with self.assertLogs(processor.logger, level=logging.ERROR) as cm:
processor.queue.put_nowait('not-bytes')
processor.queue.put_nowait('not-bytes') # type: ignore[arg-type]
while processor.queue.qsize() > 0:
await asyncio.sleep(0.1)
@ -189,15 +190,16 @@ class TCPProcessingTests(ProcessorTestCase):
async def test_socket_closure_while_sending(self):
state = {'first_time': True}
real_transport_write = self.processor.protocol.transport.write
protocol = typing.cast(statsd.TCPProtocol, self.processor.protocol)
real_transport_write = protocol.transport.write
def fake_transport_write(buffer):
def fake_transport_write(data):
if state['first_time']:
self.processor.protocol.transport.close()
state['first_time'] = False
return real_transport_write(buffer)
return real_transport_write(data)
self.processor.protocol.transport.write = fake_transport_write
protocol.transport.write = fake_transport_write
self.processor.queue.put_nowait(b'counter:1|c')
await self.wait_for(self.statsd_server.message_received.acquire())
@ -265,8 +267,8 @@ class ConnectorTests(ProcessorTestCase):
await super().asyncTearDown()
def assert_metrics_equal(self, recvd: bytes, path, value, type_code):
recvd = recvd.decode('utf-8')
recvd_path, _, rest = recvd.partition(':')
decoded = recvd.decode('utf-8')
recvd_path, _, rest = decoded.partition(':')
recvd_value, _, recvd_code = rest.partition('|')
self.assertEqual(path, recvd_path, 'metric path mismatch')
self.assertEqual(recvd_value, str(value), 'metric value mismatch')

View file

@ -257,19 +257,19 @@ class RequestHandlerTests(AsyncTestCaseWithTimeout, testing.AsyncHTTPTestCase):
timeout_remaining -= (time.time() - start)
def parse_metric(self, metric_line: bytes) -> ParsedMetric:
metric_line = metric_line.decode()
path, _, rest = metric_line.partition(':')
decoded = metric_line.decode()
path, _, rest = decoded.partition(':')
value, _, type_code = rest.partition('|')
try:
value = float(value)
parsed_value = float(value)
except ValueError:
self.fail(f'value of {path} is not a number: value={value!r}')
return path, value, type_code
return path, parsed_value, type_code
def find_metric(self, needle: str) -> ParsedMetric:
needle = needle.encode()
encoded = needle.encode()
for line in self.statsd_server.metrics:
if needle in line:
if encoded in line:
return self.parse_metric(line)
self.fail(f'failed to find metric containing {needle!r}')