diff --git a/README.rst b/README.rst index f031691..1b1bedd 100644 --- a/README.rst +++ b/README.rst @@ -44,6 +44,15 @@ The following table details the environment variable configuration options: | ``POSTGRES_UUID`` | Enable UUID support in the client. | ``TRUE`` | +---------------------------------+--------------------------------------------------+-----------+ +If ``POSTGRES_URL`` uses a scheme of ``postgresql+srv``, a SRV DNS lookup will be +performed and the lowest priority record with the highest weight will be selected +for connecting to Postgres. + +AWS's ECS service discovery does not follow the SRV standard, but creates SRV +records. If ``POSTGRES_URL`` uses a scheme of ``aws+srv``, a SRV DNS lookup will be +performed using the correct format for ECS service discovery. The lowest priority +record with the highest weight will be selected for connecting to Postgres. + Requirements ------------ - `aiopg `_ diff --git a/VERSION b/VERSION index 9084fa2..26aaba0 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.1.0 +1.2.0 diff --git a/docs/configuration.rst b/docs/configuration.rst index 53a0329..d08c29b 100644 --- a/docs/configuration.rst +++ b/docs/configuration.rst @@ -25,3 +25,12 @@ details the configuration options and their defaults. +---------------------------------+--------------------------------------------------+-----------+ | ``POSTGRES_UUID`` | Enable UUID support in the client. | ``TRUE`` | +---------------------------------+--------------------------------------------------+-----------+ + +If ``POSTGRES_URL`` uses a scheme of ``postgresql+srv``, a SRV DNS lookup will be +performed and the lowest priority record with the highest weight will be selected +for connecting to Postgres. + +AWS's ECS service discovery does not follow the SRV standard, but creates SRV +records. If ``POSTGRES_URL`` uses a scheme of ``aws+srv``, a SRV DNS lookup will be +performed using the correct format for ECS service discovery. The lowest priority +record with the highest weight will be selected for connecting to Postgres. diff --git a/setup.cfg b/setup.cfg index 8e5e635..0f44aff 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,6 +34,7 @@ keywords = include_package_data = True install_requires = aiopg>=1.0.0,<2 + aiodns sprockets.http>=2.1.1,<3 tornado>=6,<7 py_modules = @@ -42,6 +43,7 @@ zip_safe = true [options.extras_require] testing = + asynctest coverage flake8 flake8-comprehensions diff --git a/sprockets_postgres.py b/sprockets_postgres.py index c05a93b..fbe3bac 100644 --- a/sprockets_postgres.py +++ b/sprockets_postgres.py @@ -1,13 +1,18 @@ import asyncio import contextlib import logging +import operator import os import time import typing from distutils import util +from urllib import parse +import aiodns import aiopg import psycopg2 +import pycares +from aiodns import error as aiodns_error from aiopg import pool from psycopg2 import errors, extras from tornado import ioloop, web @@ -365,9 +370,11 @@ class ApplicationMixin: } """ + LOGGER.debug('Querying postgres status') query_error = asyncio.Event() - def on_error(_metric_name, _exc) -> None: + def on_error(metric_name, exc) -> None: + LOGGER.debug('Query Error for %r: %r', metric_name, exc) query_error.set() return None @@ -383,6 +390,31 @@ class ApplicationMixin: 'pool_free': self._postgres_pool.freesize } + async def _postgres_url_from_srv(self, parsed: parse.ParseResult) -> str: + if parsed.scheme.startswith('postgresql+'): + host_parts = parsed.hostname.split('.') + records = await self._resolve_srv( + '_{}._{}.{}'.format( + host_parts[0], 'postgresql', '.'.join(host_parts[1:]))) + elif parsed.scheme.startswith('aws+'): + records = await self._resolve_srv(parsed.hostname) + else: + raise RuntimeError('Unsupported URI Scheme: {}'.format( + parsed.scheme)) + + if not records: + raise RuntimeError('No SRV records found') + + if parsed.username and not parsed.password: + return 'postgresql://{}@{}:{}{}'.format( + parsed.username, records[0].host, records[0].port, parsed.path) + elif parsed.username and parsed.password: + return 'postgresql://{}:{}@{}:{}{}'.format( + parsed.username, parsed.password, + records[0].host, records[0].port, parsed.path) + return 'postgresql://{}:{}{}'.format( + records[0].host, records[0].port, parsed.path) + async def _postgres_setup(self, _app: web.Application, loop: ioloop.IOLoop) -> None: @@ -395,8 +427,20 @@ class ApplicationMixin: if 'POSTGRES_URL' not in os.environ: LOGGER.critical('Missing POSTGRES_URL environment variable') return self.stop(loop) + + parsed = parse.urlparse(os.environ['POSTGRES_URL']) + if parsed.scheme.endswith('+srv'): + try: + url = await self._postgres_url_from_srv(parsed) + except RuntimeError as error: + LOGGER.critical(str(error)) + return self.stop(loop) + else: + url = os.environ['POSTGRES_URL'] + + LOGGER.debug('Connecting to %s', os.environ['POSTGRES_URL']) self._postgres_pool = pool.Pool( - os.environ['POSTGRES_URL'], + url, maxsize=int( os.environ.get( 'POSTGRES_MAX_POOL_SIZE', @@ -429,6 +473,7 @@ class ApplicationMixin: psycopg2.Error) as error: # pragma: nocover LOGGER.warning('Error connecting to PostgreSQL on startup: %s', error) + LOGGER.debug('Connected to Postgres') async def _postgres_shutdown(self, _ioloop: ioloop.IOLoop) -> None: """Shutdown the Postgres connections and wait for them to close. @@ -441,6 +486,18 @@ class ApplicationMixin: self._postgres_pool.close() await self._postgres_pool.wait_closed() + @staticmethod + async def _resolve_srv(hostname: str) \ + -> typing.List[pycares.ares_query_srv_result]: + resolver = aiodns.DNSResolver(loop=asyncio.get_event_loop()) + try: + records = await resolver.query(hostname, 'SRV') + except aiodns_error.DNSError as error: + LOGGER.critical('DNS resolution error: %s', error) + raise RuntimeError(str(error)) + s = sorted(records, key=operator.attrgetter('weight'), reverse=True) + return sorted(s, key=operator.attrgetter('priority')) + class RequestHandlerMixin: """ diff --git a/tests.py b/tests.py index f96d8b5..4c2b210 100644 --- a/tests.py +++ b/tests.py @@ -1,12 +1,16 @@ import asyncio +import collections import json import os import typing import unittest import uuid -from unittest import mock +from urllib import parse +import asynctest import psycopg2 +import pycares +from asynctest import mock from psycopg2 import errors from sprockets.http import app, testing from tornado import ioloop, web @@ -433,3 +437,128 @@ class MissingURLTestCase(unittest.TestCase): obj.start(io_loop) io_loop.start() obj.stop.assert_called_once() + + +SRV = collections.namedtuple( + 'SRV', ['host', 'port', 'priority', 'weight', 'ttl']) + + +class SRVTestCase(asynctest.TestCase): + + async def test_srv_result(self): + obj = Application() + result = await obj._resolve_srv('_xmpp-server._tcp.google.com') + self.assertIsInstance(result[0], pycares.ares_query_srv_result) + self.assertGreater(result[0].ttl, 0) + + async def test_srv_error(self): + obj = Application() + with self.assertRaises(RuntimeError): + await obj._resolve_srv('_postgresql._tcp.foo') + + @mock.patch('sprockets.http.app.Application.stop') + @mock.patch('sprockets_postgres.LOGGER.critical') + async def test_aws_srv_parsing(self, critical, stop): + obj = Application() + loop = ioloop.IOLoop.current() + with mock.patch.object(obj, '_resolve_srv') as resolve_srv: + resolve_srv.return_value = [] + os.environ['POSTGRES_URL'] = 'aws+srv://foo@bar/baz' + await obj._postgres_setup(obj, loop) + stop.assert_called_once_with(loop) + critical.assert_called_once_with('No SRV records found') + + @mock.patch('sprockets.http.app.Application.stop') + @mock.patch('sprockets_postgres.LOGGER.critical') + async def test_postgres_srv_parsing(self, critical, stop): + obj = Application() + loop = ioloop.IOLoop.current() + with mock.patch.object(obj, '_resolve_srv') as resolve_srv: + resolve_srv.return_value = [] + os.environ['POSTGRES_URL'] = 'postgresql+srv://foo@bar/baz' + await obj._postgres_setup(obj, loop) + stop.assert_called_once_with(loop) + critical.assert_called_once_with('No SRV records found') + + @mock.patch('sprockets.http.app.Application.stop') + @mock.patch('sprockets_postgres.LOGGER.critical') + async def test_unsupported_srv_uri(self, critical, stop): + obj = Application() + loop = ioloop.IOLoop.current() + os.environ['POSTGRES_URL'] = 'postgres+srv://foo@bar/baz' + await obj._postgres_setup(obj, loop) + stop.assert_called_once_with(loop) + critical.assert_called_once_with( + 'Unsupported URI Scheme: postgres+srv') + + @mock.patch('aiodns.DNSResolver.query') + async def test_aws_url_from_srv_variation_1(self, query): + obj = Application() + parsed = parse.urlparse('aws+srv://foo@bar.baz/qux') + future = asyncio.Future() + future.set_result([ + SRV('foo2', 5432, 2, 0, 32), + SRV('foo1', 5432, 1, 1, 32), + SRV('foo3', 6432, 1, 0, 32) + ]) + query.return_value = future + url = await obj._postgres_url_from_srv(parsed) + query.assert_called_once_with('bar.baz', 'SRV') + self.assertEqual(url, 'postgresql://foo@foo1:5432/qux') + + @mock.patch('aiodns.DNSResolver.query') + async def test_postgresql_url_from_srv_variation_1(self, query): + obj = Application() + parsed = parse.urlparse('postgresql+srv://foo@bar.baz/qux') + future = asyncio.Future() + future.set_result([ + SRV('foo2', 5432, 2, 0, 32), + SRV('foo1', 5432, 1, 0, 32) + ]) + query.return_value = future + url = await obj._postgres_url_from_srv(parsed) + query.assert_called_once_with('_bar._postgresql.baz', 'SRV') + self.assertEqual(url, 'postgresql://foo@foo1:5432/qux') + + @mock.patch('aiodns.DNSResolver.query') + async def test_postgresql_url_from_srv_variation_2(self, query): + obj = Application() + parsed = parse.urlparse('postgresql+srv://foo:bar@baz.qux/corgie') + future = asyncio.Future() + future.set_result([ + SRV('foo2', 5432, 2, 0, 32), + SRV('foo1', 5432, 1, 0, 32) + ]) + query.return_value = future + url = await obj._postgres_url_from_srv(parsed) + query.assert_called_once_with('_baz._postgresql.qux', 'SRV') + self.assertEqual(url, 'postgresql://foo:bar@foo1:5432/corgie') + + @mock.patch('aiodns.DNSResolver.query') + async def test_postgresql_url_from_srv_variation_3(self, query): + obj = Application() + parsed = parse.urlparse('postgresql+srv://foo.bar/baz') + future = asyncio.Future() + future.set_result([ + SRV('foo2', 5432, 2, 0, 32), + SRV('foo1', 5432, 1, 0, 32), + SRV('foo3', 5432, 1, 10, 32), + ]) + query.return_value = future + url = await obj._postgres_url_from_srv(parsed) + query.assert_called_once_with('_foo._postgresql.bar', 'SRV') + self.assertEqual(url, 'postgresql://foo3:5432/baz') + + @mock.patch('aiodns.DNSResolver.query') + async def test_resolve_srv_sorted(self, query): + obj = Application() + result = [ + SRV('foo2', 5432, 2, 0, 32), + SRV('foo1', 5432, 1, 1, 32), + SRV('foo3', 6432, 1, 0, 32) + ] + future = asyncio.Future() + future.set_result(result) + query.return_value = future + records = await obj._resolve_srv('foo') + self.assertListEqual(records, [result[1], result[2], result[0]])