mirror of
https://github.com/sprockets/sprockets-postgres.git
synced 2024-11-13 03:00:19 +00:00
Add support for SRV based configuration
Supports both postgresql+srv and aws+srv where aws+srv supports the ECS service discovery Route53 SRV record lookup behavior.
This commit is contained in:
parent
10b98cba10
commit
3e36210c91
6 changed files with 210 additions and 4 deletions
|
@ -44,6 +44,15 @@ The following table details the environment variable configuration options:
|
|||
| ``POSTGRES_UUID`` | Enable UUID support in the client. | ``TRUE`` |
|
||||
+---------------------------------+--------------------------------------------------+-----------+
|
||||
|
||||
If ``POSTGRES_URL`` uses a scheme of ``postgresql+srv``, a SRV DNS lookup will be
|
||||
performed and the lowest priority record with the highest weight will be selected
|
||||
for connecting to Postgres.
|
||||
|
||||
AWS's ECS service discovery does not follow the SRV standard, but creates SRV
|
||||
records. If ``POSTGRES_URL`` uses a scheme of ``aws+srv``, a SRV DNS lookup will be
|
||||
performed using the correct format for ECS service discovery. The lowest priority
|
||||
record with the highest weight will be selected for connecting to Postgres.
|
||||
|
||||
Requirements
|
||||
------------
|
||||
- `aiopg <https://aioboto3.readthedocs.io/en/latest/>`_
|
||||
|
|
2
VERSION
2
VERSION
|
@ -1 +1 @@
|
|||
1.1.0
|
||||
1.2.0
|
||||
|
|
|
@ -25,3 +25,12 @@ details the configuration options and their defaults.
|
|||
+---------------------------------+--------------------------------------------------+-----------+
|
||||
| ``POSTGRES_UUID`` | Enable UUID support in the client. | ``TRUE`` |
|
||||
+---------------------------------+--------------------------------------------------+-----------+
|
||||
|
||||
If ``POSTGRES_URL`` uses a scheme of ``postgresql+srv``, a SRV DNS lookup will be
|
||||
performed and the lowest priority record with the highest weight will be selected
|
||||
for connecting to Postgres.
|
||||
|
||||
AWS's ECS service discovery does not follow the SRV standard, but creates SRV
|
||||
records. If ``POSTGRES_URL`` uses a scheme of ``aws+srv``, a SRV DNS lookup will be
|
||||
performed using the correct format for ECS service discovery. The lowest priority
|
||||
record with the highest weight will be selected for connecting to Postgres.
|
||||
|
|
|
@ -34,6 +34,7 @@ keywords =
|
|||
include_package_data = True
|
||||
install_requires =
|
||||
aiopg>=1.0.0,<2
|
||||
aiodns
|
||||
sprockets.http>=2.1.1,<3
|
||||
tornado>=6,<7
|
||||
py_modules =
|
||||
|
@ -42,6 +43,7 @@ zip_safe = true
|
|||
|
||||
[options.extras_require]
|
||||
testing =
|
||||
asynctest
|
||||
coverage
|
||||
flake8
|
||||
flake8-comprehensions
|
||||
|
|
|
@ -1,13 +1,18 @@
|
|||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import operator
|
||||
import os
|
||||
import time
|
||||
import typing
|
||||
from distutils import util
|
||||
from urllib import parse
|
||||
|
||||
import aiodns
|
||||
import aiopg
|
||||
import psycopg2
|
||||
import pycares
|
||||
from aiodns import error as aiodns_error
|
||||
from aiopg import pool
|
||||
from psycopg2 import errors, extras
|
||||
from tornado import ioloop, web
|
||||
|
@ -365,9 +370,11 @@ class ApplicationMixin:
|
|||
}
|
||||
|
||||
"""
|
||||
LOGGER.debug('Querying postgres status')
|
||||
query_error = asyncio.Event()
|
||||
|
||||
def on_error(_metric_name, _exc) -> None:
|
||||
def on_error(metric_name, exc) -> None:
|
||||
LOGGER.debug('Query Error for %r: %r', metric_name, exc)
|
||||
query_error.set()
|
||||
return None
|
||||
|
||||
|
@ -383,6 +390,31 @@ class ApplicationMixin:
|
|||
'pool_free': self._postgres_pool.freesize
|
||||
}
|
||||
|
||||
async def _postgres_url_from_srv(self, parsed: parse.ParseResult) -> str:
|
||||
if parsed.scheme.startswith('postgresql+'):
|
||||
host_parts = parsed.hostname.split('.')
|
||||
records = await self._resolve_srv(
|
||||
'_{}._{}.{}'.format(
|
||||
host_parts[0], 'postgresql', '.'.join(host_parts[1:])))
|
||||
elif parsed.scheme.startswith('aws+'):
|
||||
records = await self._resolve_srv(parsed.hostname)
|
||||
else:
|
||||
raise RuntimeError('Unsupported URI Scheme: {}'.format(
|
||||
parsed.scheme))
|
||||
|
||||
if not records:
|
||||
raise RuntimeError('No SRV records found')
|
||||
|
||||
if parsed.username and not parsed.password:
|
||||
return 'postgresql://{}@{}:{}{}'.format(
|
||||
parsed.username, records[0].host, records[0].port, parsed.path)
|
||||
elif parsed.username and parsed.password:
|
||||
return 'postgresql://{}:{}@{}:{}{}'.format(
|
||||
parsed.username, parsed.password,
|
||||
records[0].host, records[0].port, parsed.path)
|
||||
return 'postgresql://{}:{}{}'.format(
|
||||
records[0].host, records[0].port, parsed.path)
|
||||
|
||||
async def _postgres_setup(self,
|
||||
_app: web.Application,
|
||||
loop: ioloop.IOLoop) -> None:
|
||||
|
@ -395,8 +427,20 @@ class ApplicationMixin:
|
|||
if 'POSTGRES_URL' not in os.environ:
|
||||
LOGGER.critical('Missing POSTGRES_URL environment variable')
|
||||
return self.stop(loop)
|
||||
|
||||
parsed = parse.urlparse(os.environ['POSTGRES_URL'])
|
||||
if parsed.scheme.endswith('+srv'):
|
||||
try:
|
||||
url = await self._postgres_url_from_srv(parsed)
|
||||
except RuntimeError as error:
|
||||
LOGGER.critical(str(error))
|
||||
return self.stop(loop)
|
||||
else:
|
||||
url = os.environ['POSTGRES_URL']
|
||||
|
||||
LOGGER.debug('Connecting to %s', os.environ['POSTGRES_URL'])
|
||||
self._postgres_pool = pool.Pool(
|
||||
os.environ['POSTGRES_URL'],
|
||||
url,
|
||||
maxsize=int(
|
||||
os.environ.get(
|
||||
'POSTGRES_MAX_POOL_SIZE',
|
||||
|
@ -429,6 +473,7 @@ class ApplicationMixin:
|
|||
psycopg2.Error) as error: # pragma: nocover
|
||||
LOGGER.warning('Error connecting to PostgreSQL on startup: %s',
|
||||
error)
|
||||
LOGGER.debug('Connected to Postgres')
|
||||
|
||||
async def _postgres_shutdown(self, _ioloop: ioloop.IOLoop) -> None:
|
||||
"""Shutdown the Postgres connections and wait for them to close.
|
||||
|
@ -441,6 +486,18 @@ class ApplicationMixin:
|
|||
self._postgres_pool.close()
|
||||
await self._postgres_pool.wait_closed()
|
||||
|
||||
@staticmethod
|
||||
async def _resolve_srv(hostname: str) \
|
||||
-> typing.List[pycares.ares_query_srv_result]:
|
||||
resolver = aiodns.DNSResolver(loop=asyncio.get_event_loop())
|
||||
try:
|
||||
records = await resolver.query(hostname, 'SRV')
|
||||
except aiodns_error.DNSError as error:
|
||||
LOGGER.critical('DNS resolution error: %s', error)
|
||||
raise RuntimeError(str(error))
|
||||
s = sorted(records, key=operator.attrgetter('weight'), reverse=True)
|
||||
return sorted(s, key=operator.attrgetter('priority'))
|
||||
|
||||
|
||||
class RequestHandlerMixin:
|
||||
"""
|
||||
|
|
131
tests.py
131
tests.py
|
@ -1,12 +1,16 @@
|
|||
import asyncio
|
||||
import collections
|
||||
import json
|
||||
import os
|
||||
import typing
|
||||
import unittest
|
||||
import uuid
|
||||
from unittest import mock
|
||||
from urllib import parse
|
||||
|
||||
import asynctest
|
||||
import psycopg2
|
||||
import pycares
|
||||
from asynctest import mock
|
||||
from psycopg2 import errors
|
||||
from sprockets.http import app, testing
|
||||
from tornado import ioloop, web
|
||||
|
@ -433,3 +437,128 @@ class MissingURLTestCase(unittest.TestCase):
|
|||
obj.start(io_loop)
|
||||
io_loop.start()
|
||||
obj.stop.assert_called_once()
|
||||
|
||||
|
||||
SRV = collections.namedtuple(
|
||||
'SRV', ['host', 'port', 'priority', 'weight', 'ttl'])
|
||||
|
||||
|
||||
class SRVTestCase(asynctest.TestCase):
|
||||
|
||||
async def test_srv_result(self):
|
||||
obj = Application()
|
||||
result = await obj._resolve_srv('_xmpp-server._tcp.google.com')
|
||||
self.assertIsInstance(result[0], pycares.ares_query_srv_result)
|
||||
self.assertGreater(result[0].ttl, 0)
|
||||
|
||||
async def test_srv_error(self):
|
||||
obj = Application()
|
||||
with self.assertRaises(RuntimeError):
|
||||
await obj._resolve_srv('_postgresql._tcp.foo')
|
||||
|
||||
@mock.patch('sprockets.http.app.Application.stop')
|
||||
@mock.patch('sprockets_postgres.LOGGER.critical')
|
||||
async def test_aws_srv_parsing(self, critical, stop):
|
||||
obj = Application()
|
||||
loop = ioloop.IOLoop.current()
|
||||
with mock.patch.object(obj, '_resolve_srv') as resolve_srv:
|
||||
resolve_srv.return_value = []
|
||||
os.environ['POSTGRES_URL'] = 'aws+srv://foo@bar/baz'
|
||||
await obj._postgres_setup(obj, loop)
|
||||
stop.assert_called_once_with(loop)
|
||||
critical.assert_called_once_with('No SRV records found')
|
||||
|
||||
@mock.patch('sprockets.http.app.Application.stop')
|
||||
@mock.patch('sprockets_postgres.LOGGER.critical')
|
||||
async def test_postgres_srv_parsing(self, critical, stop):
|
||||
obj = Application()
|
||||
loop = ioloop.IOLoop.current()
|
||||
with mock.patch.object(obj, '_resolve_srv') as resolve_srv:
|
||||
resolve_srv.return_value = []
|
||||
os.environ['POSTGRES_URL'] = 'postgresql+srv://foo@bar/baz'
|
||||
await obj._postgres_setup(obj, loop)
|
||||
stop.assert_called_once_with(loop)
|
||||
critical.assert_called_once_with('No SRV records found')
|
||||
|
||||
@mock.patch('sprockets.http.app.Application.stop')
|
||||
@mock.patch('sprockets_postgres.LOGGER.critical')
|
||||
async def test_unsupported_srv_uri(self, critical, stop):
|
||||
obj = Application()
|
||||
loop = ioloop.IOLoop.current()
|
||||
os.environ['POSTGRES_URL'] = 'postgres+srv://foo@bar/baz'
|
||||
await obj._postgres_setup(obj, loop)
|
||||
stop.assert_called_once_with(loop)
|
||||
critical.assert_called_once_with(
|
||||
'Unsupported URI Scheme: postgres+srv')
|
||||
|
||||
@mock.patch('aiodns.DNSResolver.query')
|
||||
async def test_aws_url_from_srv_variation_1(self, query):
|
||||
obj = Application()
|
||||
parsed = parse.urlparse('aws+srv://foo@bar.baz/qux')
|
||||
future = asyncio.Future()
|
||||
future.set_result([
|
||||
SRV('foo2', 5432, 2, 0, 32),
|
||||
SRV('foo1', 5432, 1, 1, 32),
|
||||
SRV('foo3', 6432, 1, 0, 32)
|
||||
])
|
||||
query.return_value = future
|
||||
url = await obj._postgres_url_from_srv(parsed)
|
||||
query.assert_called_once_with('bar.baz', 'SRV')
|
||||
self.assertEqual(url, 'postgresql://foo@foo1:5432/qux')
|
||||
|
||||
@mock.patch('aiodns.DNSResolver.query')
|
||||
async def test_postgresql_url_from_srv_variation_1(self, query):
|
||||
obj = Application()
|
||||
parsed = parse.urlparse('postgresql+srv://foo@bar.baz/qux')
|
||||
future = asyncio.Future()
|
||||
future.set_result([
|
||||
SRV('foo2', 5432, 2, 0, 32),
|
||||
SRV('foo1', 5432, 1, 0, 32)
|
||||
])
|
||||
query.return_value = future
|
||||
url = await obj._postgres_url_from_srv(parsed)
|
||||
query.assert_called_once_with('_bar._postgresql.baz', 'SRV')
|
||||
self.assertEqual(url, 'postgresql://foo@foo1:5432/qux')
|
||||
|
||||
@mock.patch('aiodns.DNSResolver.query')
|
||||
async def test_postgresql_url_from_srv_variation_2(self, query):
|
||||
obj = Application()
|
||||
parsed = parse.urlparse('postgresql+srv://foo:bar@baz.qux/corgie')
|
||||
future = asyncio.Future()
|
||||
future.set_result([
|
||||
SRV('foo2', 5432, 2, 0, 32),
|
||||
SRV('foo1', 5432, 1, 0, 32)
|
||||
])
|
||||
query.return_value = future
|
||||
url = await obj._postgres_url_from_srv(parsed)
|
||||
query.assert_called_once_with('_baz._postgresql.qux', 'SRV')
|
||||
self.assertEqual(url, 'postgresql://foo:bar@foo1:5432/corgie')
|
||||
|
||||
@mock.patch('aiodns.DNSResolver.query')
|
||||
async def test_postgresql_url_from_srv_variation_3(self, query):
|
||||
obj = Application()
|
||||
parsed = parse.urlparse('postgresql+srv://foo.bar/baz')
|
||||
future = asyncio.Future()
|
||||
future.set_result([
|
||||
SRV('foo2', 5432, 2, 0, 32),
|
||||
SRV('foo1', 5432, 1, 0, 32),
|
||||
SRV('foo3', 5432, 1, 10, 32),
|
||||
])
|
||||
query.return_value = future
|
||||
url = await obj._postgres_url_from_srv(parsed)
|
||||
query.assert_called_once_with('_foo._postgresql.bar', 'SRV')
|
||||
self.assertEqual(url, 'postgresql://foo3:5432/baz')
|
||||
|
||||
@mock.patch('aiodns.DNSResolver.query')
|
||||
async def test_resolve_srv_sorted(self, query):
|
||||
obj = Application()
|
||||
result = [
|
||||
SRV('foo2', 5432, 2, 0, 32),
|
||||
SRV('foo1', 5432, 1, 1, 32),
|
||||
SRV('foo3', 6432, 1, 0, 32)
|
||||
]
|
||||
future = asyncio.Future()
|
||||
future.set_result(result)
|
||||
query.return_value = future
|
||||
records = await obj._resolve_srv('foo')
|
||||
self.assertListEqual(records, [result[1], result[2], result[0]])
|
||||
|
|
Loading…
Reference in a new issue