Add reconfiguration for SRV based connections

- Attempt to add graceful reconfiguration for SRV based connections. On connection failure, the pool will be closed and reopened after fetching new SRV records.
- When using SRV, use all returned hosts in the PostgreSQL URL
- If multiple requests hit a disconnect error, the reconnect logic will allow the first one in to reconnect and the others to wait
This commit is contained in:
Gavin M. Roy 2020-07-07 16:17:56 -04:00
parent 9c448c11e3
commit ecc070e974
4 changed files with 205 additions and 103 deletions

View file

@ -1 +1 @@
1.2.0
1.3.0

View file

@ -27,10 +27,10 @@ details the configuration options and their defaults.
+---------------------------------+--------------------------------------------------+-----------+
If ``POSTGRES_URL`` uses a scheme of ``postgresql+srv``, a SRV DNS lookup will be
performed and the lowest priority record with the highest weight will be selected
for connecting to Postgres.
performed and the URL will be constructed containing all host and port combinations
in priority and weighted order, utilizing `libpq's support <https://www.postgresql.org/docs/12/libpq-connect.html>`_
for multiple hosts in a URL.
AWS's ECS service discovery does not follow the SRV standard, but creates SRV
records. If ``POSTGRES_URL`` uses a scheme of ``aws+srv``, a SRV DNS lookup will be
performed using the correct format for ECS service discovery. The lowest priority
record with the highest weight will be selected for connecting to Postgres.
performed using the correct format for ECS service discovery.

View file

@ -276,10 +276,9 @@ class ConnectionException(Exception):
class ApplicationMixin:
"""
:class:`sprockets.http.app.Application` / :class:`tornado.web.Application`
mixin for handling the connection to Postgres and exporting functions for
querying the database, getting the status, and proving a cursor.
""":class:`sprockets.http.app.Application` mixin for handling the
connection to Postgres and exporting functions for querying the database,
getting the status, and providing a cursor.
Automatically creates and shuts down :class:`aiopg.Pool` on startup
and shutdown by installing `on_start` and `shutdown` callbacks into the
@ -291,7 +290,10 @@ class ApplicationMixin:
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._postgres_pool: typing.Optional[pool.Pool] = None
self.runner_callbacks['on_start'].append(self._postgres_setup)
self._postgres_connected = asyncio.Event()
self._postgres_reconnect = asyncio.Lock()
self._postgres_srv: bool = False
self.runner_callbacks['on_start'].append(self._postgres_on_start)
self.runner_callbacks['shutdown'].append(self._postgres_shutdown)
@contextlib.asynccontextmanager
@ -299,7 +301,8 @@ class ApplicationMixin:
on_error: typing.Callable,
on_duration: typing.Optional[
typing.Callable] = None,
timeout: Timeout = None) \
timeout: Timeout = None,
_attempt: int = 1) \
-> typing.AsyncContextManager[PostgresConnector]:
"""Asynchronous :ref:`context-manager <python:typecontextmanager>`
that returns a :class:`~sprockets_postgres.PostgresConnector` instance
@ -338,7 +341,19 @@ class ApplicationMixin:
yield PostgresConnector(
cursor, on_error, on_duration, timeout)
except (asyncio.TimeoutError, psycopg2.Error) as err:
exc = on_error('postgres_connector', ConnectionException(str(err)))
message = str(err)
if isinstance(err, psycopg2.OperationalError) and _attempt == 1:
LOGGER.critical('Disconnected from Postgres: %s', err)
if not self._postgres_reconnect.locked():
async with self._postgres_reconnect:
if await self._postgres_connect():
async with self.postgres_connector(
on_error, on_duration, timeout,
_attempt + 1) as connector:
yield connector
return
message = 'disconnected'
exc = on_error('postgres_connector', ConnectionException(message))
if exc:
raise exc
else: # postgres_status.on_error does not return an exception
@ -370,6 +385,13 @@ class ApplicationMixin:
}
"""
if not self._postgres_connected.is_set():
return {
'available': False,
'pool_size': 0,
'pool_free': 0
}
LOGGER.debug('Querying postgres status')
query_error = asyncio.Event()
@ -390,56 +412,27 @@ class ApplicationMixin:
'pool_free': self._postgres_pool.freesize
}
async def _postgres_url_from_srv(self, parsed: parse.ParseResult) -> str:
    """Build a ``postgresql://`` URL from the first resolved SRV record.

    For ``postgresql+*`` schemes the first label of the hostname is used
    as the service name and the remainder as the domain, producing a
    ``_<service>._postgresql.<domain>`` lookup. For ``aws+*`` schemes the
    hostname is looked up unchanged.

    :param parsed: The parsed value of the configured URL
    :raises RuntimeError: if the scheme is unsupported or the lookup
        returned no records

    """
    scheme = parsed.scheme
    if scheme.startswith('postgresql+'):
        service, _, domain = parsed.hostname.partition('.')
        lookup = '_{}._{}.{}'.format(service, 'postgresql', domain)
    elif scheme.startswith('aws+'):
        lookup = parsed.hostname
    else:
        raise RuntimeError('Unsupported URI Scheme: {}'.format(scheme))

    records = await self._resolve_srv(lookup)
    if not records:
        raise RuntimeError('No SRV records found')

    # Credentials are only emitted when a username is present; the
    # password is appended only when it is non-empty.
    auth = ''
    if parsed.username:
        auth = parsed.username
        if parsed.password:
            auth = '{}:{}'.format(auth, parsed.password)
        auth += '@'

    first = records[0]
    return 'postgresql://{}{}:{}{}'.format(
        auth, first.host, first.port, parsed.path)
async def _postgres_setup(self,
_app: web.Application,
loop: ioloop.IOLoop) -> None:
"""Setup the Postgres pool of connections and log if there is an error.
This is invoked by the :class:`sprockets.http.app.Application` on start
callback mechanism.
"""
if 'POSTGRES_URL' not in os.environ:
LOGGER.critical('Missing POSTGRES_URL environment variable')
return self.stop(loop)
async def _postgres_connect(self) -> bool:
"""Setup the Postgres pool of connections"""
self._postgres_connected.clear()
parsed = parse.urlparse(os.environ['POSTGRES_URL'])
if parsed.scheme.endswith('+srv'):
self._postgres_srv = True
try:
url = await self._postgres_url_from_srv(parsed)
except RuntimeError as error:
LOGGER.critical(str(error))
return self.stop(loop)
return False
else:
url = os.environ['POSTGRES_URL']
if self._postgres_pool:
self._postgres_pool.close()
LOGGER.debug('Connecting to %s', os.environ['POSTGRES_URL'])
self._postgres_pool = pool.Pool(
try:
self._postgres_pool = await pool.Pool.from_pool_fill(
url,
maxsize=int(
os.environ.get(
@ -466,14 +459,30 @@ class ApplicationMixin:
os.environ.get(
'POSTGRES_CONNECTION_TTL',
DEFAULT_POSTGRES_CONNECTION_TTL)))
try:
async with self._postgres_pool._cond:
await self._postgres_pool._fill_free_pool(False)
except (psycopg2.OperationalError,
psycopg2.Error) as error: # pragma: nocover
LOGGER.warning('Error connecting to PostgreSQL on startup: %s',
error)
return False
self._postgres_connected.set()
LOGGER.debug('Connected to Postgres')
return True
async def _postgres_on_start(self,
                             _app: web.Application,
                             loop: ioloop.IOLoop):
    """Connect to Postgres during application startup.

    Registered in the ``on_start`` runner callbacks. The application is
    stopped when ``POSTGRES_URL`` is not set in the environment or when
    :meth:`_postgres_connect` reports a failure.

    """
    if os.environ.get('POSTGRES_URL') is None:
        reason = 'Missing POSTGRES_URL environment variable'
    elif await self._postgres_connect():
        # Connected; nothing further to do on startup
        return None
    else:
        reason = 'PostgreSQL failed to connect, shutting down'
    LOGGER.critical(reason)
    return self.stop(loop)
async def _postgres_shutdown(self, _ioloop: ioloop.IOLoop) -> None:
"""Shutdown the Postgres connections and wait for them to close.
@ -486,6 +495,32 @@ class ApplicationMixin:
self._postgres_pool.close()
await self._postgres_pool.wait_closed()
async def _postgres_url_from_srv(self, parsed: parse.ParseResult) -> str:
    """Build a multi-host ``postgresql://`` URL from SRV records.

    For ``postgresql+*`` schemes the first label of the hostname is used
    as the service name and the remainder as the domain, producing a
    ``_<service>._postgresql.<domain>`` lookup. For ``aws+*`` schemes the
    hostname is looked up unchanged. Every record returned by
    :meth:`_resolve_srv` is added to the URL as a ``host:port`` pair, in
    the order the records were returned.

    :param parsed: The parsed value of the configured URL
    :raises RuntimeError: if the scheme is unsupported, the URL has no
        hostname to look up, or the lookup returned no records

    """
    standard = parsed.scheme.startswith('postgresql+')
    if not standard and not parsed.scheme.startswith('aws+'):
        raise RuntimeError('Unsupported URI Scheme: {}'.format(
            parsed.scheme))

    # Without this guard a URL such as ``postgresql+srv:///db`` fails with
    # an AttributeError instead of the RuntimeError that callers handle.
    if not parsed.hostname:
        raise RuntimeError('No hostname specified for SRV lookup')

    if standard:
        service, _, domain = parsed.hostname.partition('.')
        lookup = '_{}._{}.{}'.format(service, 'postgresql', domain)
    else:
        lookup = parsed.hostname

    records = await self._resolve_srv(lookup)
    if not records:
        raise RuntimeError('No SRV records found')

    # Credentials are only emitted when a username is present; the
    # password is appended only when it is non-empty.
    credentials = ''
    if parsed.username:
        credentials = parsed.username
        if parsed.password:
            credentials = '{}:{}'.format(credentials, parsed.password)
        credentials += '@'

    hosts = ','.join('{}:{}'.format(r.host, r.port) for r in records)
    return parse.urlunparse(
        ('postgresql', credentials + hosts, parsed.path,
         parsed.params, parsed.query, ''))
@staticmethod
async def _resolve_srv(hostname: str) \
-> typing.List[pycares.ares_query_srv_result]:

View file

@ -1,5 +1,6 @@
import asyncio
import collections
import contextlib
import json
import os
import typing
@ -7,17 +8,21 @@ import unittest
import uuid
from urllib import parse
import aiopg
import asynctest
import psycopg2
import pycares
from asynctest import mock
from psycopg2 import errors
from sprockets.http import app, testing
from tornado import ioloop, web
from tornado import ioloop, testing as ttesting, web
import sprockets_postgres
test_postgres_cursor_oer_invocation = 0
class RequestHandler(sprockets_postgres.RequestHandlerMixin,
web.RequestHandler):
"""Base RequestHandler for test endpoints"""
@ -223,6 +228,11 @@ class Application(sprockets_postgres.ApplicationMixin,
class TestCase(testing.SprocketsHttpTestCase):
def setUp(self):
    """Wait for the application to flag Postgres as connected."""
    super().setUp()
    connected = self.app._postgres_connected
    loop = asyncio.get_event_loop()
    loop.run_until_complete(connected.wait())
@classmethod
def setUpClass(cls):
with open('build/test-environment') as f:
@ -267,6 +277,12 @@ class RequestHandlerMixinTestCase(TestCase):
self.assertEqual(response.code, 503)
self.assertFalse(json.loads(response.body)['available'])
def test_postgres_status_not_connected(self):
    """The status endpoint reports unavailable when not connected."""
    self.app._postgres_connected.clear()
    result = self.fetch('/status')
    self.assertEqual(result.code, 503)
    available = json.loads(result.body)['available']
    self.assertFalse(available)
@mock.patch('aiopg.cursor.Cursor.execute')
def test_postgres_status_error(self, execute):
execute.side_effect = asyncio.TimeoutError()
@ -384,6 +400,55 @@ class RequestHandlerMixinTestCase(TestCase):
response = self.fetch('/execute?value=1')
self.assertEqual(response.code, 503)
def test_postgres_cursor_operational_error_reconnects(self):
    """Verify one reconnect attempt follows an OperationalError.

    The first request for a cursor raises
    :exc:`psycopg2.OperationalError`; later requests are delegated to the
    real implementation, so the HTTP request is expected to succeed after
    ``_postgres_connect`` (patched here) has been invoked exactly once.

    """
    original = aiopg.connection.Connection.cursor
    # Counted in the closure rather than in a module global so the test
    # starts from zero every time it runs in the same process.
    invocations = 0

    @contextlib.asynccontextmanager
    async def mock_cursor(self, name=None, cursor_factory=None,
                          scrollable=None, withhold=False, timeout=None):
        nonlocal invocations
        invocations += 1
        if invocations == 1:
            raise psycopg2.OperationalError()
        async with original(self, name, cursor_factory, scrollable,
                            withhold, timeout) as value:
            yield value

    aiopg.connection.Connection.cursor = mock_cursor
    try:
        with mock.patch.object(self.app, '_postgres_connect') as connect:
            response = self.fetch('/execute?value=1')
            self.assertEqual(response.code, 200)
            self.assertEqual(json.loads(response.body)['value'], '1')
            connect.assert_called_once()
    finally:
        # Restore the class attribute even when an assertion fails so the
        # replacement does not leak into other tests.
        aiopg.connection.Connection.cursor = original
@mock.patch('aiopg.connection.Connection.cursor')
def test_postgres_cursor_raises(self, cursor):
    """A failed reconnect after an OperationalError yields a 503."""
    cursor.side_effect = psycopg2.OperationalError()
    patcher = mock.patch.object(self.app, '_postgres_connect')
    with patcher as reconnect:
        reconnect.return_value = False
        result = self.fetch('/execute?value=1')
        self.assertEqual(result.code, 503)
        reconnect.assert_called_once()
@ttesting.gen_test()
@mock.patch('aiopg.connection.Connection.cursor')
async def test_postgres_cursor_failure_concurrency(self, cursor):
    """Two concurrent connector uses that both fail raise RuntimeError."""
    cursor.side_effect = psycopg2.OperationalError()

    def on_error(*args):
        # Returning an exception type makes the connector raise it
        return RuntimeError

    async def run_query():
        connector_ctx = self.app.postgres_connector(on_error)
        async with connector_ctx as connector:
            await connector.execute('SELECT 1')

    pending = [run_query(), run_query()]
    with self.assertRaises(RuntimeError):
        await asyncio.gather(*pending)
class TransactionTestCase(TestCase):
@ -464,9 +529,9 @@ class SRVTestCase(asynctest.TestCase):
with mock.patch.object(obj, '_resolve_srv') as resolve_srv:
resolve_srv.return_value = []
os.environ['POSTGRES_URL'] = 'aws+srv://foo@bar/baz'
await obj._postgres_setup(obj, loop)
await obj._postgres_on_start(obj, loop)
stop.assert_called_once_with(loop)
critical.assert_called_once_with('No SRV records found')
critical.assert_any_call('No SRV records found')
@mock.patch('sprockets.http.app.Application.stop')
@mock.patch('sprockets_postgres.LOGGER.critical')
@ -476,9 +541,9 @@ class SRVTestCase(asynctest.TestCase):
with mock.patch.object(obj, '_resolve_srv') as resolve_srv:
resolve_srv.return_value = []
os.environ['POSTGRES_URL'] = 'postgresql+srv://foo@bar/baz'
await obj._postgres_setup(obj, loop)
await obj._postgres_on_start(obj, loop)
stop.assert_called_once_with(loop)
critical.assert_called_once_with('No SRV records found')
critical.assert_any_call('No SRV records found')
@mock.patch('sprockets.http.app.Application.stop')
@mock.patch('sprockets_postgres.LOGGER.critical')
@ -486,9 +551,9 @@ class SRVTestCase(asynctest.TestCase):
obj = Application()
loop = ioloop.IOLoop.current()
os.environ['POSTGRES_URL'] = 'postgres+srv://foo@bar/baz'
await obj._postgres_setup(obj, loop)
await obj._postgres_on_start(obj, loop)
stop.assert_called_once_with(loop)
critical.assert_called_once_with(
critical.assert_any_call(
'Unsupported URI Scheme: postgres+srv')
@mock.patch('aiodns.DNSResolver.query')
@ -504,7 +569,8 @@ class SRVTestCase(asynctest.TestCase):
query.return_value = future
url = await obj._postgres_url_from_srv(parsed)
query.assert_called_once_with('bar.baz', 'SRV')
self.assertEqual(url, 'postgresql://foo@foo1:5432/qux')
self.assertEqual(
url, 'postgresql://foo@foo1:5432,foo3:6432,foo2:5432/qux')
@mock.patch('aiodns.DNSResolver.query')
async def test_postgresql_url_from_srv_variation_1(self, query):
@ -518,7 +584,7 @@ class SRVTestCase(asynctest.TestCase):
query.return_value = future
url = await obj._postgres_url_from_srv(parsed)
query.assert_called_once_with('_bar._postgresql.baz', 'SRV')
self.assertEqual(url, 'postgresql://foo@foo1:5432/qux')
self.assertEqual(url, 'postgresql://foo@foo1:5432,foo2:5432/qux')
@mock.patch('aiodns.DNSResolver.query')
async def test_postgresql_url_from_srv_variation_2(self, query):
@ -526,13 +592,14 @@ class SRVTestCase(asynctest.TestCase):
parsed = parse.urlparse('postgresql+srv://foo:bar@baz.qux/corgie')
future = asyncio.Future()
future.set_result([
SRV('foo2', 5432, 2, 0, 32),
SRV('foo1', 5432, 1, 0, 32)
SRV('foo2', 5432, 1, 0, 32),
SRV('foo1', 5432, 2, 0, 32)
])
query.return_value = future
url = await obj._postgres_url_from_srv(parsed)
query.assert_called_once_with('_baz._postgresql.qux', 'SRV')
self.assertEqual(url, 'postgresql://foo:bar@foo1:5432/corgie')
self.assertEqual(
url, 'postgresql://foo:bar@foo2:5432,foo1:5432/corgie')
@mock.patch('aiodns.DNSResolver.query')
async def test_postgresql_url_from_srv_variation_3(self, query):
@ -547,7 +614,7 @@ class SRVTestCase(asynctest.TestCase):
query.return_value = future
url = await obj._postgres_url_from_srv(parsed)
query.assert_called_once_with('_foo._postgresql.bar', 'SRV')
self.assertEqual(url, 'postgresql://foo3:5432/baz')
self.assertEqual(url, 'postgresql://foo3:5432,foo1:5432,foo2:5432/baz')
@mock.patch('aiodns.DNSResolver.query')
async def test_resolve_srv_sorted(self, query):