Add support for SRV based configuration

Supports both postgresql+srv and aws+srv where aws+srv supports the ECS service discovery Route53 SRV record lookup behavior.
This commit is contained in:
Gavin M. Roy 2020-06-02 11:26:07 -04:00
parent 10b98cba10
commit 3e36210c91
6 changed files with 210 additions and 4 deletions

View file

@ -44,6 +44,15 @@ The following table details the environment variable configuration options:
| ``POSTGRES_UUID`` | Enable UUID support in the client. | ``TRUE`` | | ``POSTGRES_UUID`` | Enable UUID support in the client. | ``TRUE`` |
+---------------------------------+--------------------------------------------------+-----------+ +---------------------------------+--------------------------------------------------+-----------+
If ``POSTGRES_URL`` uses a scheme of ``postgresql+srv``, a SRV DNS lookup will be
performed and the lowest priority record with the highest weight will be selected
for connecting to Postgres.
AWS's ECS service discovery does not follow the SRV standard, but creates SRV
records. If ``POSTGRES_URL`` uses a scheme of ``aws+srv``, a SRV DNS lookup will be
performed using the correct format for ECS service discovery. The lowest priority
record with the highest weight will be selected for connecting to Postgres.
Requirements Requirements
------------ ------------
- `aiopg <https://aioboto3.readthedocs.io/en/latest/>`_ - `aiopg <https://aioboto3.readthedocs.io/en/latest/>`_

View file

@ -1 +1 @@
1.1.0 1.2.0

View file

@ -25,3 +25,12 @@ details the configuration options and their defaults.
+---------------------------------+--------------------------------------------------+-----------+ +---------------------------------+--------------------------------------------------+-----------+
| ``POSTGRES_UUID`` | Enable UUID support in the client. | ``TRUE`` | | ``POSTGRES_UUID`` | Enable UUID support in the client. | ``TRUE`` |
+---------------------------------+--------------------------------------------------+-----------+ +---------------------------------+--------------------------------------------------+-----------+
If ``POSTGRES_URL`` uses a scheme of ``postgresql+srv``, a SRV DNS lookup will be
performed and the lowest priority record with the highest weight will be selected
for connecting to Postgres.
AWS's ECS service discovery does not follow the SRV standard, but creates SRV
records. If ``POSTGRES_URL`` uses a scheme of ``aws+srv``, a SRV DNS lookup will be
performed using the correct format for ECS service discovery. The lowest priority
record with the highest weight will be selected for connecting to Postgres.

View file

@ -34,6 +34,7 @@ keywords =
include_package_data = True include_package_data = True
install_requires = install_requires =
aiopg>=1.0.0,<2 aiopg>=1.0.0,<2
aiodns
sprockets.http>=2.1.1,<3 sprockets.http>=2.1.1,<3
tornado>=6,<7 tornado>=6,<7
py_modules = py_modules =
@ -42,6 +43,7 @@ zip_safe = true
[options.extras_require] [options.extras_require]
testing = testing =
asynctest
coverage coverage
flake8 flake8
flake8-comprehensions flake8-comprehensions

View file

@ -1,13 +1,18 @@
import asyncio import asyncio
import contextlib import contextlib
import logging import logging
import operator
import os import os
import time import time
import typing import typing
from distutils import util from distutils import util
from urllib import parse
import aiodns
import aiopg import aiopg
import psycopg2 import psycopg2
import pycares
from aiodns import error as aiodns_error
from aiopg import pool from aiopg import pool
from psycopg2 import errors, extras from psycopg2 import errors, extras
from tornado import ioloop, web from tornado import ioloop, web
@ -365,9 +370,11 @@ class ApplicationMixin:
} }
""" """
LOGGER.debug('Querying postgres status')
query_error = asyncio.Event() query_error = asyncio.Event()
def on_error(_metric_name, _exc) -> None: def on_error(metric_name, exc) -> None:
LOGGER.debug('Query Error for %r: %r', metric_name, exc)
query_error.set() query_error.set()
return None return None
@ -383,6 +390,31 @@ class ApplicationMixin:
'pool_free': self._postgres_pool.freesize 'pool_free': self._postgres_pool.freesize
} }
async def _postgres_url_from_srv(self, parsed: parse.ParseResult) -> str:
if parsed.scheme.startswith('postgresql+'):
host_parts = parsed.hostname.split('.')
records = await self._resolve_srv(
'_{}._{}.{}'.format(
host_parts[0], 'postgresql', '.'.join(host_parts[1:])))
elif parsed.scheme.startswith('aws+'):
records = await self._resolve_srv(parsed.hostname)
else:
raise RuntimeError('Unsupported URI Scheme: {}'.format(
parsed.scheme))
if not records:
raise RuntimeError('No SRV records found')
if parsed.username and not parsed.password:
return 'postgresql://{}@{}:{}{}'.format(
parsed.username, records[0].host, records[0].port, parsed.path)
elif parsed.username and parsed.password:
return 'postgresql://{}:{}@{}:{}{}'.format(
parsed.username, parsed.password,
records[0].host, records[0].port, parsed.path)
return 'postgresql://{}:{}{}'.format(
records[0].host, records[0].port, parsed.path)
async def _postgres_setup(self, async def _postgres_setup(self,
_app: web.Application, _app: web.Application,
loop: ioloop.IOLoop) -> None: loop: ioloop.IOLoop) -> None:
@ -395,8 +427,20 @@ class ApplicationMixin:
if 'POSTGRES_URL' not in os.environ: if 'POSTGRES_URL' not in os.environ:
LOGGER.critical('Missing POSTGRES_URL environment variable') LOGGER.critical('Missing POSTGRES_URL environment variable')
return self.stop(loop) return self.stop(loop)
parsed = parse.urlparse(os.environ['POSTGRES_URL'])
if parsed.scheme.endswith('+srv'):
try:
url = await self._postgres_url_from_srv(parsed)
except RuntimeError as error:
LOGGER.critical(str(error))
return self.stop(loop)
else:
url = os.environ['POSTGRES_URL']
LOGGER.debug('Connecting to %s', os.environ['POSTGRES_URL'])
self._postgres_pool = pool.Pool( self._postgres_pool = pool.Pool(
os.environ['POSTGRES_URL'], url,
maxsize=int( maxsize=int(
os.environ.get( os.environ.get(
'POSTGRES_MAX_POOL_SIZE', 'POSTGRES_MAX_POOL_SIZE',
@ -429,6 +473,7 @@ class ApplicationMixin:
psycopg2.Error) as error: # pragma: nocover psycopg2.Error) as error: # pragma: nocover
LOGGER.warning('Error connecting to PostgreSQL on startup: %s', LOGGER.warning('Error connecting to PostgreSQL on startup: %s',
error) error)
LOGGER.debug('Connected to Postgres')
async def _postgres_shutdown(self, _ioloop: ioloop.IOLoop) -> None: async def _postgres_shutdown(self, _ioloop: ioloop.IOLoop) -> None:
"""Shutdown the Postgres connections and wait for them to close. """Shutdown the Postgres connections and wait for them to close.
@ -441,6 +486,18 @@ class ApplicationMixin:
self._postgres_pool.close() self._postgres_pool.close()
await self._postgres_pool.wait_closed() await self._postgres_pool.wait_closed()
@staticmethod
async def _resolve_srv(hostname: str) \
-> typing.List[pycares.ares_query_srv_result]:
resolver = aiodns.DNSResolver(loop=asyncio.get_event_loop())
try:
records = await resolver.query(hostname, 'SRV')
except aiodns_error.DNSError as error:
LOGGER.critical('DNS resolution error: %s', error)
raise RuntimeError(str(error))
s = sorted(records, key=operator.attrgetter('weight'), reverse=True)
return sorted(s, key=operator.attrgetter('priority'))
class RequestHandlerMixin: class RequestHandlerMixin:
""" """

131
tests.py
View file

@ -1,12 +1,16 @@
import asyncio import asyncio
import collections
import json import json
import os import os
import typing import typing
import unittest import unittest
import uuid import uuid
from unittest import mock from urllib import parse
import asynctest
import psycopg2 import psycopg2
import pycares
from asynctest import mock
from psycopg2 import errors from psycopg2 import errors
from sprockets.http import app, testing from sprockets.http import app, testing
from tornado import ioloop, web from tornado import ioloop, web
@ -433,3 +437,128 @@ class MissingURLTestCase(unittest.TestCase):
obj.start(io_loop) obj.start(io_loop)
io_loop.start() io_loop.start()
obj.stop.assert_called_once() obj.stop.assert_called_once()
SRV = collections.namedtuple(
'SRV', ['host', 'port', 'priority', 'weight', 'ttl'])
class SRVTestCase(asynctest.TestCase):
async def test_srv_result(self):
obj = Application()
result = await obj._resolve_srv('_xmpp-server._tcp.google.com')
self.assertIsInstance(result[0], pycares.ares_query_srv_result)
self.assertGreater(result[0].ttl, 0)
async def test_srv_error(self):
obj = Application()
with self.assertRaises(RuntimeError):
await obj._resolve_srv('_postgresql._tcp.foo')
@mock.patch('sprockets.http.app.Application.stop')
@mock.patch('sprockets_postgres.LOGGER.critical')
async def test_aws_srv_parsing(self, critical, stop):
obj = Application()
loop = ioloop.IOLoop.current()
with mock.patch.object(obj, '_resolve_srv') as resolve_srv:
resolve_srv.return_value = []
os.environ['POSTGRES_URL'] = 'aws+srv://foo@bar/baz'
await obj._postgres_setup(obj, loop)
stop.assert_called_once_with(loop)
critical.assert_called_once_with('No SRV records found')
@mock.patch('sprockets.http.app.Application.stop')
@mock.patch('sprockets_postgres.LOGGER.critical')
async def test_postgres_srv_parsing(self, critical, stop):
obj = Application()
loop = ioloop.IOLoop.current()
with mock.patch.object(obj, '_resolve_srv') as resolve_srv:
resolve_srv.return_value = []
os.environ['POSTGRES_URL'] = 'postgresql+srv://foo@bar/baz'
await obj._postgres_setup(obj, loop)
stop.assert_called_once_with(loop)
critical.assert_called_once_with('No SRV records found')
@mock.patch('sprockets.http.app.Application.stop')
@mock.patch('sprockets_postgres.LOGGER.critical')
async def test_unsupported_srv_uri(self, critical, stop):
obj = Application()
loop = ioloop.IOLoop.current()
os.environ['POSTGRES_URL'] = 'postgres+srv://foo@bar/baz'
await obj._postgres_setup(obj, loop)
stop.assert_called_once_with(loop)
critical.assert_called_once_with(
'Unsupported URI Scheme: postgres+srv')
@mock.patch('aiodns.DNSResolver.query')
async def test_aws_url_from_srv_variation_1(self, query):
obj = Application()
parsed = parse.urlparse('aws+srv://foo@bar.baz/qux')
future = asyncio.Future()
future.set_result([
SRV('foo2', 5432, 2, 0, 32),
SRV('foo1', 5432, 1, 1, 32),
SRV('foo3', 6432, 1, 0, 32)
])
query.return_value = future
url = await obj._postgres_url_from_srv(parsed)
query.assert_called_once_with('bar.baz', 'SRV')
self.assertEqual(url, 'postgresql://foo@foo1:5432/qux')
@mock.patch('aiodns.DNSResolver.query')
async def test_postgresql_url_from_srv_variation_1(self, query):
obj = Application()
parsed = parse.urlparse('postgresql+srv://foo@bar.baz/qux')
future = asyncio.Future()
future.set_result([
SRV('foo2', 5432, 2, 0, 32),
SRV('foo1', 5432, 1, 0, 32)
])
query.return_value = future
url = await obj._postgres_url_from_srv(parsed)
query.assert_called_once_with('_bar._postgresql.baz', 'SRV')
self.assertEqual(url, 'postgresql://foo@foo1:5432/qux')
@mock.patch('aiodns.DNSResolver.query')
async def test_postgresql_url_from_srv_variation_2(self, query):
obj = Application()
parsed = parse.urlparse('postgresql+srv://foo:bar@baz.qux/corgie')
future = asyncio.Future()
future.set_result([
SRV('foo2', 5432, 2, 0, 32),
SRV('foo1', 5432, 1, 0, 32)
])
query.return_value = future
url = await obj._postgres_url_from_srv(parsed)
query.assert_called_once_with('_baz._postgresql.qux', 'SRV')
self.assertEqual(url, 'postgresql://foo:bar@foo1:5432/corgie')
@mock.patch('aiodns.DNSResolver.query')
async def test_postgresql_url_from_srv_variation_3(self, query):
obj = Application()
parsed = parse.urlparse('postgresql+srv://foo.bar/baz')
future = asyncio.Future()
future.set_result([
SRV('foo2', 5432, 2, 0, 32),
SRV('foo1', 5432, 1, 0, 32),
SRV('foo3', 5432, 1, 10, 32),
])
query.return_value = future
url = await obj._postgres_url_from_srv(parsed)
query.assert_called_once_with('_foo._postgresql.bar', 'SRV')
self.assertEqual(url, 'postgresql://foo3:5432/baz')
@mock.patch('aiodns.DNSResolver.query')
async def test_resolve_srv_sorted(self, query):
obj = Application()
result = [
SRV('foo2', 5432, 2, 0, 32),
SRV('foo1', 5432, 1, 1, 32),
SRV('foo3', 6432, 1, 0, 32)
]
future = asyncio.Future()
future.set_result(result)
query.return_value = future
records = await obj._resolve_srv('foo')
self.assertListEqual(records, [result[1], result[2], result[0]])