Add reconfiguration for SRV based connections

- Attempt to add graceful reconfiguration for SRV based connections. On connection failure, the pool will be closed and reopened after fetching to get new SRV records.
- When using SRV, use all return hosts in the PostgreSQL URL
- If multiple requests hit a disconnect error, the reconnect logic will allow the first one in to reconnect and the others to wait
This commit is contained in:
Gavin M. Roy 2020-07-07 16:17:56 -04:00
parent 9c448c11e3
commit ecc070e974
4 changed files with 205 additions and 103 deletions

View file

@ -1 +1 @@
1.2.0 1.3.0

View file

@ -27,10 +27,10 @@ details the configuration options and their defaults.
+---------------------------------+--------------------------------------------------+-----------+ +---------------------------------+--------------------------------------------------+-----------+
If ``POSTGRES_URL`` uses a scheme of ``postgresql+srv``, a SRV DNS lookup will be If ``POSTGRES_URL`` uses a scheme of ``postgresql+srv``, a SRV DNS lookup will be
performed and the lowest priority record with the highest weight will be selected performed and the URL will be constructed containing all host and port combinations
for connecting to Postgres. in priority and weighted order, utilizing `libpq's support <https://www.postgresql.org/docs/12/libpq-connect.html>`_
for multiple hosts in a URL.
AWS's ECS service discovery does not follow the SRV standard, but creates SRV AWS's ECS service discovery does not follow the SRV standard, but creates SRV
records. If ``POSTGRES_URL`` uses a scheme of ``aws+srv``, a SRV DNS lookup will be records. If ``POSTGRES_URL`` uses a scheme of ``aws+srv``, a SRV DNS lookup will be
performed using the correct format for ECS service discovery. The lowest priority performed using the correct format for ECS service discovery.
record with the highest weight will be selected for connecting to Postgres.

View file

@ -276,10 +276,9 @@ class ConnectionException(Exception):
class ApplicationMixin: class ApplicationMixin:
""" """:class:`sprockets.http.app.Application` mixin for handling the
:class:`sprockets.http.app.Application` / :class:`tornado.web.Application` connection to Postgres and exporting functions for querying the database,
mixin for handling the connection to Postgres and exporting functions for getting the status, and proving a cursor.
querying the database, getting the status, and proving a cursor.
Automatically creates and shuts down :class:`aiopg.Pool` on startup Automatically creates and shuts down :class:`aiopg.Pool` on startup
and shutdown by installing `on_start` and `shutdown` callbacks into the and shutdown by installing `on_start` and `shutdown` callbacks into the
@ -291,7 +290,10 @@ class ApplicationMixin:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._postgres_pool: typing.Optional[pool.Pool] = None self._postgres_pool: typing.Optional[pool.Pool] = None
self.runner_callbacks['on_start'].append(self._postgres_setup) self._postgres_connected = asyncio.Event()
self._postgres_reconnect = asyncio.Lock()
self._postgres_srv: bool = False
self.runner_callbacks['on_start'].append(self._postgres_on_start)
self.runner_callbacks['shutdown'].append(self._postgres_shutdown) self.runner_callbacks['shutdown'].append(self._postgres_shutdown)
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
@ -299,7 +301,8 @@ class ApplicationMixin:
on_error: typing.Callable, on_error: typing.Callable,
on_duration: typing.Optional[ on_duration: typing.Optional[
typing.Callable] = None, typing.Callable] = None,
timeout: Timeout = None) \ timeout: Timeout = None,
_attempt: int = 1) \
-> typing.AsyncContextManager[PostgresConnector]: -> typing.AsyncContextManager[PostgresConnector]:
"""Asynchronous :ref:`context-manager <python:typecontextmanager>` """Asynchronous :ref:`context-manager <python:typecontextmanager>`
that returns a :class:`~sprockets_postgres.PostgresConnector` instance that returns a :class:`~sprockets_postgres.PostgresConnector` instance
@ -338,7 +341,19 @@ class ApplicationMixin:
yield PostgresConnector( yield PostgresConnector(
cursor, on_error, on_duration, timeout) cursor, on_error, on_duration, timeout)
except (asyncio.TimeoutError, psycopg2.Error) as err: except (asyncio.TimeoutError, psycopg2.Error) as err:
exc = on_error('postgres_connector', ConnectionException(str(err))) message = str(err)
if isinstance(err, psycopg2.OperationalError) and _attempt == 1:
LOGGER.critical('Disconnected from Postgres: %s', err)
if not self._postgres_reconnect.locked():
async with self._postgres_reconnect:
if await self._postgres_connect():
async with self.postgres_connector(
on_error, on_duration, timeout,
_attempt + 1) as connector:
yield connector
return
message = 'disconnected'
exc = on_error('postgres_connector', ConnectionException(message))
if exc: if exc:
raise exc raise exc
else: # postgres_status.on_error does not return an exception else: # postgres_status.on_error does not return an exception
@ -370,6 +385,13 @@ class ApplicationMixin:
} }
""" """
if not self._postgres_connected.is_set():
return {
'available': False,
'pool_size': 0,
'pool_free': 0
}
LOGGER.debug('Querying postgres status') LOGGER.debug('Querying postgres status')
query_error = asyncio.Event() query_error = asyncio.Event()
@ -390,6 +412,89 @@ class ApplicationMixin:
'pool_free': self._postgres_pool.freesize 'pool_free': self._postgres_pool.freesize
} }
async def _postgres_connect(self) -> bool:
    """Create (or re-create) the Postgres connection pool.

    Clears the ``_postgres_connected`` event first so status checks
    report unavailable while (re)connecting, and sets it again only
    after the pool is filled.

    :returns: :data:`True` on success, :data:`False` when SRV
        resolution fails or the pool cannot be filled.

    """
    self._postgres_connected.clear()
    parsed = parse.urlparse(os.environ['POSTGRES_URL'])
    if parsed.scheme.endswith('+srv'):
        # SRV-style scheme: resolve DNS SRV records and build a
        # multi-host postgresql:// URL from the results
        self._postgres_srv = True
        try:
            url = await self._postgres_url_from_srv(parsed)
        except RuntimeError as error:
            LOGGER.critical(str(error))
            return False
    else:
        url = os.environ['POSTGRES_URL']
    if self._postgres_pool:
        # Reconnect path: discard the previous pool before building
        # a new one (e.g. after SRV records changed)
        self._postgres_pool.close()
    LOGGER.debug('Connecting to %s', os.environ['POSTGRES_URL'])
    try:
        # from_pool_fill creates the pool and eagerly fills it so
        # connection failures surface here rather than on first use
        self._postgres_pool = await pool.Pool.from_pool_fill(
            url,
            maxsize=int(
                os.environ.get(
                    'POSTGRES_MAX_POOL_SIZE',
                    DEFAULT_POSTGRES_MAX_POOL_SIZE)),
            minsize=int(
                os.environ.get(
                    'POSTGRES_MIN_POOL_SIZE',
                    DEFAULT_POSTGRES_MIN_POOL_SIZE)),
            timeout=int(
                os.environ.get(
                    'POSTGRES_CONNECT_TIMEOUT',
                    DEFAULT_POSTGRES_CONNECTION_TIMEOUT)),
            enable_hstore=util.strtobool(
                os.environ.get(
                    'POSTGRES_HSTORE', DEFAULT_POSTGRES_HSTORE)),
            enable_json=util.strtobool(
                os.environ.get('POSTGRES_JSON', DEFAULT_POSTGRES_JSON)),
            enable_uuid=util.strtobool(
                os.environ.get('POSTGRES_UUID', DEFAULT_POSTGRES_UUID)),
            echo=False,
            on_connect=None,
            pool_recycle=int(
                os.environ.get(
                    'POSTGRES_CONNECTION_TTL',
                    DEFAULT_POSTGRES_CONNECTION_TTL)))
    except (psycopg2.OperationalError,
            psycopg2.Error) as error:  # pragma: nocover
        LOGGER.warning('Error connecting to PostgreSQL on startup: %s',
                       error)
        return False
    self._postgres_connected.set()
    LOGGER.debug('Connected to Postgres')
    return True
async def _postgres_on_start(self,
                             _app: web.Application,
                             loop: ioloop.IOLoop):
    """Application startup callback that establishes the Postgres pool.

    Invoked via the :class:`sprockets.http.app.Application` ``on_start``
    callback mechanism.  Stops the application when ``POSTGRES_URL`` is
    missing or the initial connection attempt fails.

    """
    if 'POSTGRES_URL' in os.environ:
        if await self._postgres_connect():
            return
        LOGGER.critical('PostgreSQL failed to connect, shutting down')
    else:
        LOGGER.critical('Missing POSTGRES_URL environment variable')
    return self.stop(loop)
async def _postgres_shutdown(self, _ioloop: ioloop.IOLoop) -> None:
    """Close the Postgres pool and wait for its connections to close.

    Invoked via the :class:`sprockets.http.app.Application` shutdown
    callback mechanism.

    """
    pg_pool = self._postgres_pool
    if pg_pool is None:
        return
    pg_pool.close()
    await pg_pool.wait_closed()
async def _postgres_url_from_srv(self, parsed: parse.ParseResult) -> str: async def _postgres_url_from_srv(self, parsed: parse.ParseResult) -> str:
if parsed.scheme.startswith('postgresql+'): if parsed.scheme.startswith('postgresql+'):
host_parts = parsed.hostname.split('.') host_parts = parsed.hostname.split('.')
@ -405,86 +510,16 @@ class ApplicationMixin:
if not records: if not records:
raise RuntimeError('No SRV records found') raise RuntimeError('No SRV records found')
netloc = []
if parsed.username and not parsed.password: if parsed.username and not parsed.password:
return 'postgresql://{}@{}:{}{}'.format( netloc.append('{}@'.format(parsed.username))
parsed.username, records[0].host, records[0].port, parsed.path)
elif parsed.username and parsed.password: elif parsed.username and parsed.password:
return 'postgresql://{}:{}@{}:{}{}'.format( netloc.append('{}:{}@'.format(parsed.username, parsed.password))
parsed.username, parsed.password, netloc.append(','.join([
records[0].host, records[0].port, parsed.path) '{}:{}'.format(r.host, r.port) for r in records]))
return 'postgresql://{}:{}{}'.format( return parse.urlunparse(
records[0].host, records[0].port, parsed.path) ('postgresql', ''.join(netloc), parsed.path,
parsed.params, parsed.query, ''))
async def _postgres_setup(self,
_app: web.Application,
loop: ioloop.IOLoop) -> None:
"""Setup the Postgres pool of connections and log if there is an error.
This is invoked by the :class:`sprockets.http.app.Application` on start
callback mechanism.
"""
if 'POSTGRES_URL' not in os.environ:
LOGGER.critical('Missing POSTGRES_URL environment variable')
return self.stop(loop)
parsed = parse.urlparse(os.environ['POSTGRES_URL'])
if parsed.scheme.endswith('+srv'):
try:
url = await self._postgres_url_from_srv(parsed)
except RuntimeError as error:
LOGGER.critical(str(error))
return self.stop(loop)
else:
url = os.environ['POSTGRES_URL']
LOGGER.debug('Connecting to %s', os.environ['POSTGRES_URL'])
self._postgres_pool = pool.Pool(
url,
maxsize=int(
os.environ.get(
'POSTGRES_MAX_POOL_SIZE',
DEFAULT_POSTGRES_MAX_POOL_SIZE)),
minsize=int(
os.environ.get(
'POSTGRES_MIN_POOL_SIZE',
DEFAULT_POSTGRES_MIN_POOL_SIZE)),
timeout=int(
os.environ.get(
'POSTGRES_CONNECT_TIMEOUT',
DEFAULT_POSTGRES_CONNECTION_TIMEOUT)),
enable_hstore=util.strtobool(
os.environ.get(
'POSTGRES_HSTORE', DEFAULT_POSTGRES_HSTORE)),
enable_json=util.strtobool(
os.environ.get('POSTGRES_JSON', DEFAULT_POSTGRES_JSON)),
enable_uuid=util.strtobool(
os.environ.get('POSTGRES_UUID', DEFAULT_POSTGRES_UUID)),
echo=False,
on_connect=None,
pool_recycle=int(
os.environ.get(
'POSTGRES_CONNECTION_TTL',
DEFAULT_POSTGRES_CONNECTION_TTL)))
try:
async with self._postgres_pool._cond:
await self._postgres_pool._fill_free_pool(False)
except (psycopg2.OperationalError,
psycopg2.Error) as error: # pragma: nocover
LOGGER.warning('Error connecting to PostgreSQL on startup: %s',
error)
LOGGER.debug('Connected to Postgres')
async def _postgres_shutdown(self, _ioloop: ioloop.IOLoop) -> None:
"""Shutdown the Postgres connections and wait for them to close.
This is invoked by the :class:`sprockets.http.app.Application` shutdown
callback mechanism.
"""
if self._postgres_pool is not None:
self._postgres_pool.close()
await self._postgres_pool.wait_closed()
@staticmethod @staticmethod
async def _resolve_srv(hostname: str) \ async def _resolve_srv(hostname: str) \

View file

@ -1,5 +1,6 @@
import asyncio import asyncio
import collections import collections
import contextlib
import json import json
import os import os
import typing import typing
@ -7,17 +8,21 @@ import unittest
import uuid import uuid
from urllib import parse from urllib import parse
import aiopg
import asynctest import asynctest
import psycopg2 import psycopg2
import pycares import pycares
from asynctest import mock from asynctest import mock
from psycopg2 import errors from psycopg2 import errors
from sprockets.http import app, testing from sprockets.http import app, testing
from tornado import ioloop, web from tornado import ioloop, testing as ttesting, web
import sprockets_postgres import sprockets_postgres
test_postgres_cursor_oer_invocation = 0
class RequestHandler(sprockets_postgres.RequestHandlerMixin, class RequestHandler(sprockets_postgres.RequestHandlerMixin,
web.RequestHandler): web.RequestHandler):
"""Base RequestHandler for test endpoints""" """Base RequestHandler for test endpoints"""
@ -223,6 +228,11 @@ class Application(sprockets_postgres.ApplicationMixin,
class TestCase(testing.SprocketsHttpTestCase): class TestCase(testing.SprocketsHttpTestCase):
def setUp(self):
super().setUp()
asyncio.get_event_loop().run_until_complete(
self.app._postgres_connected.wait())
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
with open('build/test-environment') as f: with open('build/test-environment') as f:
@ -267,6 +277,12 @@ class RequestHandlerMixinTestCase(TestCase):
self.assertEqual(response.code, 503) self.assertEqual(response.code, 503)
self.assertFalse(json.loads(response.body)['available']) self.assertFalse(json.loads(response.body)['available'])
def test_postgres_status_not_connected(self):
    """/status reports unavailable with a 503 while the connected
    event is cleared (i.e. before/without a Postgres connection)."""
    self.app._postgres_connected.clear()
    response = self.fetch('/status')
    self.assertEqual(response.code, 503)
    self.assertFalse(json.loads(response.body)['available'])
@mock.patch('aiopg.cursor.Cursor.execute') @mock.patch('aiopg.cursor.Cursor.execute')
def test_postgres_status_error(self, execute): def test_postgres_status_error(self, execute):
execute.side_effect = asyncio.TimeoutError() execute.side_effect = asyncio.TimeoutError()
@ -384,6 +400,55 @@ class RequestHandlerMixinTestCase(TestCase):
response = self.fetch('/execute?value=1') response = self.fetch('/execute?value=1')
self.assertEqual(response.code, 503) self.assertEqual(response.code, 503)
def test_postgres_cursor_operational_error_reconnects(self):
    """A single OperationalError from the cursor triggers one
    reconnect attempt and the request then succeeds."""
    original = aiopg.connection.Connection.cursor

    @contextlib.asynccontextmanager
    async def mock_cursor(self, name=None, cursor_factory=None,
                          scrollable=None, withhold=False, timeout=None):
        # Module-level counter so only the FIRST cursor acquisition
        # fails; subsequent calls delegate to the real implementation.
        global test_postgres_cursor_oer_invocation
        test_postgres_cursor_oer_invocation += 1
        if test_postgres_cursor_oer_invocation == 1:
            raise psycopg2.OperationalError()
        async with original(self, name, cursor_factory, scrollable,
                            withhold, timeout) as value:
            yield value

    # Patch the class attribute directly (not mock.patch) so the
    # wrapper can still call through to the bound original.
    aiopg.connection.Connection.cursor = mock_cursor

    with mock.patch.object(self.app, '_postgres_connect') as connect:
        response = self.fetch('/execute?value=1')
        self.assertEqual(response.code, 200)
        self.assertEqual(json.loads(response.body)['value'], '1')
        connect.assert_called_once()

    # Restore the real cursor so later tests are unaffected
    aiopg.connection.Connection.cursor = original
@mock.patch('aiopg.connection.Connection.cursor')
def test_postgres_cursor_raises(self, cursor):
    """If reconnection also fails (_postgres_connect returns False),
    the request errors out with a 503."""
    cursor.side_effect = psycopg2.OperationalError()
    with mock.patch.object(self.app, '_postgres_connect') as connect:
        connect.return_value = False
        response = self.fetch('/execute?value=1')
        self.assertEqual(response.code, 503)
        connect.assert_called_once()
@ttesting.gen_test()
@mock.patch('aiopg.connection.Connection.cursor')
async def test_postgres_cursor_failure_concurrency(self, cursor):
    """Two concurrent requests hitting a disconnect both surface the
    on_error exception; the reconnect lock lets only one attempt in."""
    cursor.side_effect = psycopg2.OperationalError()

    def on_error(*args):
        # Return (not raise) the exception type, matching the
        # postgres_connector on_error contract
        return RuntimeError

    async def invoke_cursor():
        async with self.app.postgres_connector(on_error) as connector:
            await connector.execute('SELECT 1')

    with self.assertRaises(RuntimeError):
        await asyncio.gather(invoke_cursor(), invoke_cursor())
class TransactionTestCase(TestCase): class TransactionTestCase(TestCase):
@ -464,9 +529,9 @@ class SRVTestCase(asynctest.TestCase):
with mock.patch.object(obj, '_resolve_srv') as resolve_srv: with mock.patch.object(obj, '_resolve_srv') as resolve_srv:
resolve_srv.return_value = [] resolve_srv.return_value = []
os.environ['POSTGRES_URL'] = 'aws+srv://foo@bar/baz' os.environ['POSTGRES_URL'] = 'aws+srv://foo@bar/baz'
await obj._postgres_setup(obj, loop) await obj._postgres_on_start(obj, loop)
stop.assert_called_once_with(loop) stop.assert_called_once_with(loop)
critical.assert_called_once_with('No SRV records found') critical.assert_any_call('No SRV records found')
@mock.patch('sprockets.http.app.Application.stop') @mock.patch('sprockets.http.app.Application.stop')
@mock.patch('sprockets_postgres.LOGGER.critical') @mock.patch('sprockets_postgres.LOGGER.critical')
@ -476,9 +541,9 @@ class SRVTestCase(asynctest.TestCase):
with mock.patch.object(obj, '_resolve_srv') as resolve_srv: with mock.patch.object(obj, '_resolve_srv') as resolve_srv:
resolve_srv.return_value = [] resolve_srv.return_value = []
os.environ['POSTGRES_URL'] = 'postgresql+srv://foo@bar/baz' os.environ['POSTGRES_URL'] = 'postgresql+srv://foo@bar/baz'
await obj._postgres_setup(obj, loop) await obj._postgres_on_start(obj, loop)
stop.assert_called_once_with(loop) stop.assert_called_once_with(loop)
critical.assert_called_once_with('No SRV records found') critical.assert_any_call('No SRV records found')
@mock.patch('sprockets.http.app.Application.stop') @mock.patch('sprockets.http.app.Application.stop')
@mock.patch('sprockets_postgres.LOGGER.critical') @mock.patch('sprockets_postgres.LOGGER.critical')
@ -486,9 +551,9 @@ class SRVTestCase(asynctest.TestCase):
obj = Application() obj = Application()
loop = ioloop.IOLoop.current() loop = ioloop.IOLoop.current()
os.environ['POSTGRES_URL'] = 'postgres+srv://foo@bar/baz' os.environ['POSTGRES_URL'] = 'postgres+srv://foo@bar/baz'
await obj._postgres_setup(obj, loop) await obj._postgres_on_start(obj, loop)
stop.assert_called_once_with(loop) stop.assert_called_once_with(loop)
critical.assert_called_once_with( critical.assert_any_call(
'Unsupported URI Scheme: postgres+srv') 'Unsupported URI Scheme: postgres+srv')
@mock.patch('aiodns.DNSResolver.query') @mock.patch('aiodns.DNSResolver.query')
@ -504,7 +569,8 @@ class SRVTestCase(asynctest.TestCase):
query.return_value = future query.return_value = future
url = await obj._postgres_url_from_srv(parsed) url = await obj._postgres_url_from_srv(parsed)
query.assert_called_once_with('bar.baz', 'SRV') query.assert_called_once_with('bar.baz', 'SRV')
self.assertEqual(url, 'postgresql://foo@foo1:5432/qux') self.assertEqual(
url, 'postgresql://foo@foo1:5432,foo3:6432,foo2:5432/qux')
@mock.patch('aiodns.DNSResolver.query') @mock.patch('aiodns.DNSResolver.query')
async def test_postgresql_url_from_srv_variation_1(self, query): async def test_postgresql_url_from_srv_variation_1(self, query):
@ -518,7 +584,7 @@ class SRVTestCase(asynctest.TestCase):
query.return_value = future query.return_value = future
url = await obj._postgres_url_from_srv(parsed) url = await obj._postgres_url_from_srv(parsed)
query.assert_called_once_with('_bar._postgresql.baz', 'SRV') query.assert_called_once_with('_bar._postgresql.baz', 'SRV')
self.assertEqual(url, 'postgresql://foo@foo1:5432/qux') self.assertEqual(url, 'postgresql://foo@foo1:5432,foo2:5432/qux')
@mock.patch('aiodns.DNSResolver.query') @mock.patch('aiodns.DNSResolver.query')
async def test_postgresql_url_from_srv_variation_2(self, query): async def test_postgresql_url_from_srv_variation_2(self, query):
@ -526,13 +592,14 @@ class SRVTestCase(asynctest.TestCase):
parsed = parse.urlparse('postgresql+srv://foo:bar@baz.qux/corgie') parsed = parse.urlparse('postgresql+srv://foo:bar@baz.qux/corgie')
future = asyncio.Future() future = asyncio.Future()
future.set_result([ future.set_result([
SRV('foo2', 5432, 2, 0, 32), SRV('foo2', 5432, 1, 0, 32),
SRV('foo1', 5432, 1, 0, 32) SRV('foo1', 5432, 2, 0, 32)
]) ])
query.return_value = future query.return_value = future
url = await obj._postgres_url_from_srv(parsed) url = await obj._postgres_url_from_srv(parsed)
query.assert_called_once_with('_baz._postgresql.qux', 'SRV') query.assert_called_once_with('_baz._postgresql.qux', 'SRV')
self.assertEqual(url, 'postgresql://foo:bar@foo1:5432/corgie') self.assertEqual(
url, 'postgresql://foo:bar@foo2:5432,foo1:5432/corgie')
@mock.patch('aiodns.DNSResolver.query') @mock.patch('aiodns.DNSResolver.query')
async def test_postgresql_url_from_srv_variation_3(self, query): async def test_postgresql_url_from_srv_variation_3(self, query):
@ -547,7 +614,7 @@ class SRVTestCase(asynctest.TestCase):
query.return_value = future query.return_value = future
url = await obj._postgres_url_from_srv(parsed) url = await obj._postgres_url_from_srv(parsed)
query.assert_called_once_with('_foo._postgresql.bar', 'SRV') query.assert_called_once_with('_foo._postgresql.bar', 'SRV')
self.assertEqual(url, 'postgresql://foo3:5432/baz') self.assertEqual(url, 'postgresql://foo3:5432,foo1:5432,foo2:5432/baz')
@mock.patch('aiodns.DNSResolver.query') @mock.patch('aiodns.DNSResolver.query')
async def test_resolve_srv_sorted(self, query): async def test_resolve_srv_sorted(self, query):