mirror of
https://github.com/sprockets/sprockets-postgres.git
synced 2024-09-21 10:10:57 +00:00
Add support for SRV based configuration
Supports both postgresql+srv and aws+srv where aws+srv supports the ECS service discovery Route53 SRV record lookup behavior.
This commit is contained in:
parent
10b98cba10
commit
3e36210c91
6 changed files with 210 additions and 4 deletions
|
@ -44,6 +44,15 @@ The following table details the environment variable configuration options:
|
||||||
| ``POSTGRES_UUID`` | Enable UUID support in the client. | ``TRUE`` |
|
| ``POSTGRES_UUID`` | Enable UUID support in the client. | ``TRUE`` |
|
||||||
+---------------------------------+--------------------------------------------------+-----------+
|
+---------------------------------+--------------------------------------------------+-----------+
|
||||||
|
|
||||||
|
If ``POSTGRES_URL`` uses a scheme of ``postgresql+srv``, a SRV DNS lookup will be
|
||||||
|
performed and the lowest priority record with the highest weight will be selected
|
||||||
|
for connecting to Postgres.
|
||||||
|
|
||||||
|
AWS's ECS service discovery does not follow the SRV standard, but creates SRV
|
||||||
|
records. If ``POSTGRES_URL`` uses a scheme of ``aws+srv``, a SRV DNS lookup will be
|
||||||
|
performed using the correct format for ECS service discovery. The lowest priority
|
||||||
|
record with the highest weight will be selected for connecting to Postgres.
|
||||||
|
|
||||||
Requirements
|
Requirements
|
||||||
------------
|
------------
|
||||||
- `aiopg <https://aioboto3.readthedocs.io/en/latest/>`_
|
- `aiopg <https://aioboto3.readthedocs.io/en/latest/>`_
|
||||||
|
|
2
VERSION
2
VERSION
|
@ -1 +1 @@
|
||||||
1.1.0
|
1.2.0
|
||||||
|
|
|
@ -25,3 +25,12 @@ details the configuration options and their defaults.
|
||||||
+---------------------------------+--------------------------------------------------+-----------+
|
+---------------------------------+--------------------------------------------------+-----------+
|
||||||
| ``POSTGRES_UUID`` | Enable UUID support in the client. | ``TRUE`` |
|
| ``POSTGRES_UUID`` | Enable UUID support in the client. | ``TRUE`` |
|
||||||
+---------------------------------+--------------------------------------------------+-----------+
|
+---------------------------------+--------------------------------------------------+-----------+
|
||||||
|
|
||||||
|
If ``POSTGRES_URL`` uses a scheme of ``postgresql+srv``, a SRV DNS lookup will be
|
||||||
|
performed and the lowest priority record with the highest weight will be selected
|
||||||
|
for connecting to Postgres.
|
||||||
|
|
||||||
|
AWS's ECS service discovery does not follow the SRV standard, but creates SRV
|
||||||
|
records. If ``POSTGRES_URL`` uses a scheme of ``aws+srv``, a SRV DNS lookup will be
|
||||||
|
performed using the correct format for ECS service discovery. The lowest priority
|
||||||
|
record with the highest weight will be selected for connecting to Postgres.
|
||||||
|
|
|
@ -34,6 +34,7 @@ keywords =
|
||||||
include_package_data = True
|
include_package_data = True
|
||||||
install_requires =
|
install_requires =
|
||||||
aiopg>=1.0.0,<2
|
aiopg>=1.0.0,<2
|
||||||
|
aiodns
|
||||||
sprockets.http>=2.1.1,<3
|
sprockets.http>=2.1.1,<3
|
||||||
tornado>=6,<7
|
tornado>=6,<7
|
||||||
py_modules =
|
py_modules =
|
||||||
|
@ -42,6 +43,7 @@ zip_safe = true
|
||||||
|
|
||||||
[options.extras_require]
|
[options.extras_require]
|
||||||
testing =
|
testing =
|
||||||
|
asynctest
|
||||||
coverage
|
coverage
|
||||||
flake8
|
flake8
|
||||||
flake8-comprehensions
|
flake8-comprehensions
|
||||||
|
|
|
@ -1,13 +1,18 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
|
import operator
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import typing
|
import typing
|
||||||
from distutils import util
|
from distutils import util
|
||||||
|
from urllib import parse
|
||||||
|
|
||||||
|
import aiodns
|
||||||
import aiopg
|
import aiopg
|
||||||
import psycopg2
|
import psycopg2
|
||||||
|
import pycares
|
||||||
|
from aiodns import error as aiodns_error
|
||||||
from aiopg import pool
|
from aiopg import pool
|
||||||
from psycopg2 import errors, extras
|
from psycopg2 import errors, extras
|
||||||
from tornado import ioloop, web
|
from tornado import ioloop, web
|
||||||
|
@ -365,9 +370,11 @@ class ApplicationMixin:
|
||||||
}
|
}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
LOGGER.debug('Querying postgres status')
|
||||||
query_error = asyncio.Event()
|
query_error = asyncio.Event()
|
||||||
|
|
||||||
def on_error(_metric_name, _exc) -> None:
|
def on_error(metric_name, exc) -> None:
|
||||||
|
LOGGER.debug('Query Error for %r: %r', metric_name, exc)
|
||||||
query_error.set()
|
query_error.set()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -383,6 +390,31 @@ class ApplicationMixin:
|
||||||
'pool_free': self._postgres_pool.freesize
|
'pool_free': self._postgres_pool.freesize
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def _postgres_url_from_srv(self, parsed: parse.ParseResult) -> str:
|
||||||
|
if parsed.scheme.startswith('postgresql+'):
|
||||||
|
host_parts = parsed.hostname.split('.')
|
||||||
|
records = await self._resolve_srv(
|
||||||
|
'_{}._{}.{}'.format(
|
||||||
|
host_parts[0], 'postgresql', '.'.join(host_parts[1:])))
|
||||||
|
elif parsed.scheme.startswith('aws+'):
|
||||||
|
records = await self._resolve_srv(parsed.hostname)
|
||||||
|
else:
|
||||||
|
raise RuntimeError('Unsupported URI Scheme: {}'.format(
|
||||||
|
parsed.scheme))
|
||||||
|
|
||||||
|
if not records:
|
||||||
|
raise RuntimeError('No SRV records found')
|
||||||
|
|
||||||
|
if parsed.username and not parsed.password:
|
||||||
|
return 'postgresql://{}@{}:{}{}'.format(
|
||||||
|
parsed.username, records[0].host, records[0].port, parsed.path)
|
||||||
|
elif parsed.username and parsed.password:
|
||||||
|
return 'postgresql://{}:{}@{}:{}{}'.format(
|
||||||
|
parsed.username, parsed.password,
|
||||||
|
records[0].host, records[0].port, parsed.path)
|
||||||
|
return 'postgresql://{}:{}{}'.format(
|
||||||
|
records[0].host, records[0].port, parsed.path)
|
||||||
|
|
||||||
async def _postgres_setup(self,
|
async def _postgres_setup(self,
|
||||||
_app: web.Application,
|
_app: web.Application,
|
||||||
loop: ioloop.IOLoop) -> None:
|
loop: ioloop.IOLoop) -> None:
|
||||||
|
@ -395,8 +427,20 @@ class ApplicationMixin:
|
||||||
if 'POSTGRES_URL' not in os.environ:
|
if 'POSTGRES_URL' not in os.environ:
|
||||||
LOGGER.critical('Missing POSTGRES_URL environment variable')
|
LOGGER.critical('Missing POSTGRES_URL environment variable')
|
||||||
return self.stop(loop)
|
return self.stop(loop)
|
||||||
|
|
||||||
|
parsed = parse.urlparse(os.environ['POSTGRES_URL'])
|
||||||
|
if parsed.scheme.endswith('+srv'):
|
||||||
|
try:
|
||||||
|
url = await self._postgres_url_from_srv(parsed)
|
||||||
|
except RuntimeError as error:
|
||||||
|
LOGGER.critical(str(error))
|
||||||
|
return self.stop(loop)
|
||||||
|
else:
|
||||||
|
url = os.environ['POSTGRES_URL']
|
||||||
|
|
||||||
|
LOGGER.debug('Connecting to %s', os.environ['POSTGRES_URL'])
|
||||||
self._postgres_pool = pool.Pool(
|
self._postgres_pool = pool.Pool(
|
||||||
os.environ['POSTGRES_URL'],
|
url,
|
||||||
maxsize=int(
|
maxsize=int(
|
||||||
os.environ.get(
|
os.environ.get(
|
||||||
'POSTGRES_MAX_POOL_SIZE',
|
'POSTGRES_MAX_POOL_SIZE',
|
||||||
|
@ -429,6 +473,7 @@ class ApplicationMixin:
|
||||||
psycopg2.Error) as error: # pragma: nocover
|
psycopg2.Error) as error: # pragma: nocover
|
||||||
LOGGER.warning('Error connecting to PostgreSQL on startup: %s',
|
LOGGER.warning('Error connecting to PostgreSQL on startup: %s',
|
||||||
error)
|
error)
|
||||||
|
LOGGER.debug('Connected to Postgres')
|
||||||
|
|
||||||
async def _postgres_shutdown(self, _ioloop: ioloop.IOLoop) -> None:
|
async def _postgres_shutdown(self, _ioloop: ioloop.IOLoop) -> None:
|
||||||
"""Shutdown the Postgres connections and wait for them to close.
|
"""Shutdown the Postgres connections and wait for them to close.
|
||||||
|
@ -441,6 +486,18 @@ class ApplicationMixin:
|
||||||
self._postgres_pool.close()
|
self._postgres_pool.close()
|
||||||
await self._postgres_pool.wait_closed()
|
await self._postgres_pool.wait_closed()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _resolve_srv(hostname: str) \
|
||||||
|
-> typing.List[pycares.ares_query_srv_result]:
|
||||||
|
resolver = aiodns.DNSResolver(loop=asyncio.get_event_loop())
|
||||||
|
try:
|
||||||
|
records = await resolver.query(hostname, 'SRV')
|
||||||
|
except aiodns_error.DNSError as error:
|
||||||
|
LOGGER.critical('DNS resolution error: %s', error)
|
||||||
|
raise RuntimeError(str(error))
|
||||||
|
s = sorted(records, key=operator.attrgetter('weight'), reverse=True)
|
||||||
|
return sorted(s, key=operator.attrgetter('priority'))
|
||||||
|
|
||||||
|
|
||||||
class RequestHandlerMixin:
|
class RequestHandlerMixin:
|
||||||
"""
|
"""
|
||||||
|
|
131
tests.py
131
tests.py
|
@ -1,12 +1,16 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import collections
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import typing
|
import typing
|
||||||
import unittest
|
import unittest
|
||||||
import uuid
|
import uuid
|
||||||
from unittest import mock
|
from urllib import parse
|
||||||
|
|
||||||
|
import asynctest
|
||||||
import psycopg2
|
import psycopg2
|
||||||
|
import pycares
|
||||||
|
from asynctest import mock
|
||||||
from psycopg2 import errors
|
from psycopg2 import errors
|
||||||
from sprockets.http import app, testing
|
from sprockets.http import app, testing
|
||||||
from tornado import ioloop, web
|
from tornado import ioloop, web
|
||||||
|
@ -433,3 +437,128 @@ class MissingURLTestCase(unittest.TestCase):
|
||||||
obj.start(io_loop)
|
obj.start(io_loop)
|
||||||
io_loop.start()
|
io_loop.start()
|
||||||
obj.stop.assert_called_once()
|
obj.stop.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
SRV = collections.namedtuple(
|
||||||
|
'SRV', ['host', 'port', 'priority', 'weight', 'ttl'])
|
||||||
|
|
||||||
|
|
||||||
|
class SRVTestCase(asynctest.TestCase):
|
||||||
|
|
||||||
|
async def test_srv_result(self):
|
||||||
|
obj = Application()
|
||||||
|
result = await obj._resolve_srv('_xmpp-server._tcp.google.com')
|
||||||
|
self.assertIsInstance(result[0], pycares.ares_query_srv_result)
|
||||||
|
self.assertGreater(result[0].ttl, 0)
|
||||||
|
|
||||||
|
async def test_srv_error(self):
|
||||||
|
obj = Application()
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
await obj._resolve_srv('_postgresql._tcp.foo')
|
||||||
|
|
||||||
|
@mock.patch('sprockets.http.app.Application.stop')
|
||||||
|
@mock.patch('sprockets_postgres.LOGGER.critical')
|
||||||
|
async def test_aws_srv_parsing(self, critical, stop):
|
||||||
|
obj = Application()
|
||||||
|
loop = ioloop.IOLoop.current()
|
||||||
|
with mock.patch.object(obj, '_resolve_srv') as resolve_srv:
|
||||||
|
resolve_srv.return_value = []
|
||||||
|
os.environ['POSTGRES_URL'] = 'aws+srv://foo@bar/baz'
|
||||||
|
await obj._postgres_setup(obj, loop)
|
||||||
|
stop.assert_called_once_with(loop)
|
||||||
|
critical.assert_called_once_with('No SRV records found')
|
||||||
|
|
||||||
|
@mock.patch('sprockets.http.app.Application.stop')
|
||||||
|
@mock.patch('sprockets_postgres.LOGGER.critical')
|
||||||
|
async def test_postgres_srv_parsing(self, critical, stop):
|
||||||
|
obj = Application()
|
||||||
|
loop = ioloop.IOLoop.current()
|
||||||
|
with mock.patch.object(obj, '_resolve_srv') as resolve_srv:
|
||||||
|
resolve_srv.return_value = []
|
||||||
|
os.environ['POSTGRES_URL'] = 'postgresql+srv://foo@bar/baz'
|
||||||
|
await obj._postgres_setup(obj, loop)
|
||||||
|
stop.assert_called_once_with(loop)
|
||||||
|
critical.assert_called_once_with('No SRV records found')
|
||||||
|
|
||||||
|
@mock.patch('sprockets.http.app.Application.stop')
|
||||||
|
@mock.patch('sprockets_postgres.LOGGER.critical')
|
||||||
|
async def test_unsupported_srv_uri(self, critical, stop):
|
||||||
|
obj = Application()
|
||||||
|
loop = ioloop.IOLoop.current()
|
||||||
|
os.environ['POSTGRES_URL'] = 'postgres+srv://foo@bar/baz'
|
||||||
|
await obj._postgres_setup(obj, loop)
|
||||||
|
stop.assert_called_once_with(loop)
|
||||||
|
critical.assert_called_once_with(
|
||||||
|
'Unsupported URI Scheme: postgres+srv')
|
||||||
|
|
||||||
|
@mock.patch('aiodns.DNSResolver.query')
|
||||||
|
async def test_aws_url_from_srv_variation_1(self, query):
|
||||||
|
obj = Application()
|
||||||
|
parsed = parse.urlparse('aws+srv://foo@bar.baz/qux')
|
||||||
|
future = asyncio.Future()
|
||||||
|
future.set_result([
|
||||||
|
SRV('foo2', 5432, 2, 0, 32),
|
||||||
|
SRV('foo1', 5432, 1, 1, 32),
|
||||||
|
SRV('foo3', 6432, 1, 0, 32)
|
||||||
|
])
|
||||||
|
query.return_value = future
|
||||||
|
url = await obj._postgres_url_from_srv(parsed)
|
||||||
|
query.assert_called_once_with('bar.baz', 'SRV')
|
||||||
|
self.assertEqual(url, 'postgresql://foo@foo1:5432/qux')
|
||||||
|
|
||||||
|
@mock.patch('aiodns.DNSResolver.query')
|
||||||
|
async def test_postgresql_url_from_srv_variation_1(self, query):
|
||||||
|
obj = Application()
|
||||||
|
parsed = parse.urlparse('postgresql+srv://foo@bar.baz/qux')
|
||||||
|
future = asyncio.Future()
|
||||||
|
future.set_result([
|
||||||
|
SRV('foo2', 5432, 2, 0, 32),
|
||||||
|
SRV('foo1', 5432, 1, 0, 32)
|
||||||
|
])
|
||||||
|
query.return_value = future
|
||||||
|
url = await obj._postgres_url_from_srv(parsed)
|
||||||
|
query.assert_called_once_with('_bar._postgresql.baz', 'SRV')
|
||||||
|
self.assertEqual(url, 'postgresql://foo@foo1:5432/qux')
|
||||||
|
|
||||||
|
@mock.patch('aiodns.DNSResolver.query')
|
||||||
|
async def test_postgresql_url_from_srv_variation_2(self, query):
|
||||||
|
obj = Application()
|
||||||
|
parsed = parse.urlparse('postgresql+srv://foo:bar@baz.qux/corgie')
|
||||||
|
future = asyncio.Future()
|
||||||
|
future.set_result([
|
||||||
|
SRV('foo2', 5432, 2, 0, 32),
|
||||||
|
SRV('foo1', 5432, 1, 0, 32)
|
||||||
|
])
|
||||||
|
query.return_value = future
|
||||||
|
url = await obj._postgres_url_from_srv(parsed)
|
||||||
|
query.assert_called_once_with('_baz._postgresql.qux', 'SRV')
|
||||||
|
self.assertEqual(url, 'postgresql://foo:bar@foo1:5432/corgie')
|
||||||
|
|
||||||
|
@mock.patch('aiodns.DNSResolver.query')
|
||||||
|
async def test_postgresql_url_from_srv_variation_3(self, query):
|
||||||
|
obj = Application()
|
||||||
|
parsed = parse.urlparse('postgresql+srv://foo.bar/baz')
|
||||||
|
future = asyncio.Future()
|
||||||
|
future.set_result([
|
||||||
|
SRV('foo2', 5432, 2, 0, 32),
|
||||||
|
SRV('foo1', 5432, 1, 0, 32),
|
||||||
|
SRV('foo3', 5432, 1, 10, 32),
|
||||||
|
])
|
||||||
|
query.return_value = future
|
||||||
|
url = await obj._postgres_url_from_srv(parsed)
|
||||||
|
query.assert_called_once_with('_foo._postgresql.bar', 'SRV')
|
||||||
|
self.assertEqual(url, 'postgresql://foo3:5432/baz')
|
||||||
|
|
||||||
|
@mock.patch('aiodns.DNSResolver.query')
|
||||||
|
async def test_resolve_srv_sorted(self, query):
|
||||||
|
obj = Application()
|
||||||
|
result = [
|
||||||
|
SRV('foo2', 5432, 2, 0, 32),
|
||||||
|
SRV('foo1', 5432, 1, 1, 32),
|
||||||
|
SRV('foo3', 6432, 1, 0, 32)
|
||||||
|
]
|
||||||
|
future = asyncio.Future()
|
||||||
|
future.set_result(result)
|
||||||
|
query.return_value = future
|
||||||
|
records = await obj._resolve_srv('foo')
|
||||||
|
self.assertListEqual(records, [result[1], result[2], result[0]])
|
||||||
|
|
Loading…
Reference in a new issue