mirror of
https://github.com/sprockets/sprockets-postgres.git
synced 2024-12-27 11:17:30 +00:00
Add reconfiguration for SRV based connections
- Attempt to add graceful reconfiguration for SRV based connections. On connection failure, the pool will be closed and reopened after fetching to get new SRV records. - When using SRV, use all return hosts in the PostgreSQL URL - If multiple requests hit a disconnect error, the reconnect logic will allow the first one in to reconnect and the others to wait
This commit is contained in:
parent
9c448c11e3
commit
ecc070e974
4 changed files with 205 additions and 103 deletions
2
VERSION
2
VERSION
|
@ -1 +1 @@
|
|||
1.2.0
|
||||
1.3.0
|
||||
|
|
|
@ -27,10 +27,10 @@ details the configuration options and their defaults.
|
|||
+---------------------------------+--------------------------------------------------+-----------+
|
||||
|
||||
If ``POSTGRES_URL`` uses a scheme of ``postgresql+srv``, a SRV DNS lookup will be
|
||||
performed and the lowest priority record with the highest weight will be selected
|
||||
for connecting to Postgres.
|
||||
performed and the URL will be constructed containing all host and port combinations
|
||||
in priority and weighted order, utilizing `libpq's supoprt <https://www.postgresql.org/docs/12/libpq-connect.html>`_
|
||||
for multiple hosts in a URL.
|
||||
|
||||
AWS's ECS service discovery does not follow the SRV standard, but creates SRV
|
||||
records. If ``POSTGRES_URL`` uses a scheme of ``aws+srv``, a SRV DNS lookup will be
|
||||
performed using the correct format for ECS service discovery. The lowest priority
|
||||
record with the highest weight will be selected for connecting to Postgres.
|
||||
performed using the correct format for ECS service discovery.
|
||||
|
|
|
@ -276,10 +276,9 @@ class ConnectionException(Exception):
|
|||
|
||||
|
||||
class ApplicationMixin:
|
||||
"""
|
||||
:class:`sprockets.http.app.Application` / :class:`tornado.web.Application`
|
||||
mixin for handling the connection to Postgres and exporting functions for
|
||||
querying the database, getting the status, and proving a cursor.
|
||||
""":class:`sprockets.http.app.Application` mixin for handling the
|
||||
connection to Postgres and exporting functions for querying the database,
|
||||
getting the status, and proving a cursor.
|
||||
|
||||
Automatically creates and shuts down :class:`aiopg.Pool` on startup
|
||||
and shutdown by installing `on_start` and `shutdown` callbacks into the
|
||||
|
@ -291,7 +290,10 @@ class ApplicationMixin:
|
|||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._postgres_pool: typing.Optional[pool.Pool] = None
|
||||
self.runner_callbacks['on_start'].append(self._postgres_setup)
|
||||
self._postgres_connected = asyncio.Event()
|
||||
self._postgres_reconnect = asyncio.Lock()
|
||||
self._postgres_srv: bool = False
|
||||
self.runner_callbacks['on_start'].append(self._postgres_on_start)
|
||||
self.runner_callbacks['shutdown'].append(self._postgres_shutdown)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
|
@ -299,7 +301,8 @@ class ApplicationMixin:
|
|||
on_error: typing.Callable,
|
||||
on_duration: typing.Optional[
|
||||
typing.Callable] = None,
|
||||
timeout: Timeout = None) \
|
||||
timeout: Timeout = None,
|
||||
_attempt: int = 1) \
|
||||
-> typing.AsyncContextManager[PostgresConnector]:
|
||||
"""Asynchronous :ref:`context-manager <python:typecontextmanager>`
|
||||
that returns a :class:`~sprockets_postgres.PostgresConnector` instance
|
||||
|
@ -338,7 +341,19 @@ class ApplicationMixin:
|
|||
yield PostgresConnector(
|
||||
cursor, on_error, on_duration, timeout)
|
||||
except (asyncio.TimeoutError, psycopg2.Error) as err:
|
||||
exc = on_error('postgres_connector', ConnectionException(str(err)))
|
||||
message = str(err)
|
||||
if isinstance(err, psycopg2.OperationalError) and _attempt == 1:
|
||||
LOGGER.critical('Disconnected from Postgres: %s', err)
|
||||
if not self._postgres_reconnect.locked():
|
||||
async with self._postgres_reconnect:
|
||||
if await self._postgres_connect():
|
||||
async with self.postgres_connector(
|
||||
on_error, on_duration, timeout,
|
||||
_attempt + 1) as connector:
|
||||
yield connector
|
||||
return
|
||||
message = 'disconnected'
|
||||
exc = on_error('postgres_connector', ConnectionException(message))
|
||||
if exc:
|
||||
raise exc
|
||||
else: # postgres_status.on_error does not return an exception
|
||||
|
@ -370,6 +385,13 @@ class ApplicationMixin:
|
|||
}
|
||||
|
||||
"""
|
||||
if not self._postgres_connected.is_set():
|
||||
return {
|
||||
'available': False,
|
||||
'pool_size': 0,
|
||||
'pool_free': 0
|
||||
}
|
||||
|
||||
LOGGER.debug('Querying postgres status')
|
||||
query_error = asyncio.Event()
|
||||
|
||||
|
@ -390,6 +412,89 @@ class ApplicationMixin:
|
|||
'pool_free': self._postgres_pool.freesize
|
||||
}
|
||||
|
||||
async def _postgres_connect(self) -> bool:
|
||||
"""Setup the Postgres pool of connections"""
|
||||
self._postgres_connected.clear()
|
||||
|
||||
parsed = parse.urlparse(os.environ['POSTGRES_URL'])
|
||||
if parsed.scheme.endswith('+srv'):
|
||||
self._postgres_srv = True
|
||||
try:
|
||||
url = await self._postgres_url_from_srv(parsed)
|
||||
except RuntimeError as error:
|
||||
LOGGER.critical(str(error))
|
||||
return False
|
||||
else:
|
||||
url = os.environ['POSTGRES_URL']
|
||||
|
||||
if self._postgres_pool:
|
||||
self._postgres_pool.close()
|
||||
|
||||
LOGGER.debug('Connecting to %s', os.environ['POSTGRES_URL'])
|
||||
try:
|
||||
self._postgres_pool = await pool.Pool.from_pool_fill(
|
||||
url,
|
||||
maxsize=int(
|
||||
os.environ.get(
|
||||
'POSTGRES_MAX_POOL_SIZE',
|
||||
DEFAULT_POSTGRES_MAX_POOL_SIZE)),
|
||||
minsize=int(
|
||||
os.environ.get(
|
||||
'POSTGRES_MIN_POOL_SIZE',
|
||||
DEFAULT_POSTGRES_MIN_POOL_SIZE)),
|
||||
timeout=int(
|
||||
os.environ.get(
|
||||
'POSTGRES_CONNECT_TIMEOUT',
|
||||
DEFAULT_POSTGRES_CONNECTION_TIMEOUT)),
|
||||
enable_hstore=util.strtobool(
|
||||
os.environ.get(
|
||||
'POSTGRES_HSTORE', DEFAULT_POSTGRES_HSTORE)),
|
||||
enable_json=util.strtobool(
|
||||
os.environ.get('POSTGRES_JSON', DEFAULT_POSTGRES_JSON)),
|
||||
enable_uuid=util.strtobool(
|
||||
os.environ.get('POSTGRES_UUID', DEFAULT_POSTGRES_UUID)),
|
||||
echo=False,
|
||||
on_connect=None,
|
||||
pool_recycle=int(
|
||||
os.environ.get(
|
||||
'POSTGRES_CONNECTION_TTL',
|
||||
DEFAULT_POSTGRES_CONNECTION_TTL)))
|
||||
except (psycopg2.OperationalError,
|
||||
psycopg2.Error) as error: # pragma: nocover
|
||||
LOGGER.warning('Error connecting to PostgreSQL on startup: %s',
|
||||
error)
|
||||
return False
|
||||
self._postgres_connected.set()
|
||||
LOGGER.debug('Connected to Postgres')
|
||||
return True
|
||||
|
||||
async def _postgres_on_start(self,
|
||||
_app: web.Application,
|
||||
loop: ioloop.IOLoop):
|
||||
"""Invoked as a startup step for the application
|
||||
|
||||
This is invoked by the :class:`sprockets.http.app.Application` on start
|
||||
callback mechanism.
|
||||
|
||||
"""
|
||||
if 'POSTGRES_URL' not in os.environ:
|
||||
LOGGER.critical('Missing POSTGRES_URL environment variable')
|
||||
return self.stop(loop)
|
||||
if not await self._postgres_connect():
|
||||
LOGGER.critical('PostgreSQL failed to connect, shutting down')
|
||||
return self.stop(loop)
|
||||
|
||||
async def _postgres_shutdown(self, _ioloop: ioloop.IOLoop) -> None:
|
||||
"""Shutdown the Postgres connections and wait for them to close.
|
||||
|
||||
This is invoked by the :class:`sprockets.http.app.Application` shutdown
|
||||
callback mechanism.
|
||||
|
||||
"""
|
||||
if self._postgres_pool is not None:
|
||||
self._postgres_pool.close()
|
||||
await self._postgres_pool.wait_closed()
|
||||
|
||||
async def _postgres_url_from_srv(self, parsed: parse.ParseResult) -> str:
|
||||
if parsed.scheme.startswith('postgresql+'):
|
||||
host_parts = parsed.hostname.split('.')
|
||||
|
@ -405,86 +510,16 @@ class ApplicationMixin:
|
|||
if not records:
|
||||
raise RuntimeError('No SRV records found')
|
||||
|
||||
netloc = []
|
||||
if parsed.username and not parsed.password:
|
||||
return 'postgresql://{}@{}:{}{}'.format(
|
||||
parsed.username, records[0].host, records[0].port, parsed.path)
|
||||
netloc.append('{}@'.format(parsed.username))
|
||||
elif parsed.username and parsed.password:
|
||||
return 'postgresql://{}:{}@{}:{}{}'.format(
|
||||
parsed.username, parsed.password,
|
||||
records[0].host, records[0].port, parsed.path)
|
||||
return 'postgresql://{}:{}{}'.format(
|
||||
records[0].host, records[0].port, parsed.path)
|
||||
|
||||
async def _postgres_setup(self,
|
||||
_app: web.Application,
|
||||
loop: ioloop.IOLoop) -> None:
|
||||
"""Setup the Postgres pool of connections and log if there is an error.
|
||||
|
||||
This is invoked by the :class:`sprockets.http.app.Application` on start
|
||||
callback mechanism.
|
||||
|
||||
"""
|
||||
if 'POSTGRES_URL' not in os.environ:
|
||||
LOGGER.critical('Missing POSTGRES_URL environment variable')
|
||||
return self.stop(loop)
|
||||
|
||||
parsed = parse.urlparse(os.environ['POSTGRES_URL'])
|
||||
if parsed.scheme.endswith('+srv'):
|
||||
try:
|
||||
url = await self._postgres_url_from_srv(parsed)
|
||||
except RuntimeError as error:
|
||||
LOGGER.critical(str(error))
|
||||
return self.stop(loop)
|
||||
else:
|
||||
url = os.environ['POSTGRES_URL']
|
||||
|
||||
LOGGER.debug('Connecting to %s', os.environ['POSTGRES_URL'])
|
||||
self._postgres_pool = pool.Pool(
|
||||
url,
|
||||
maxsize=int(
|
||||
os.environ.get(
|
||||
'POSTGRES_MAX_POOL_SIZE',
|
||||
DEFAULT_POSTGRES_MAX_POOL_SIZE)),
|
||||
minsize=int(
|
||||
os.environ.get(
|
||||
'POSTGRES_MIN_POOL_SIZE',
|
||||
DEFAULT_POSTGRES_MIN_POOL_SIZE)),
|
||||
timeout=int(
|
||||
os.environ.get(
|
||||
'POSTGRES_CONNECT_TIMEOUT',
|
||||
DEFAULT_POSTGRES_CONNECTION_TIMEOUT)),
|
||||
enable_hstore=util.strtobool(
|
||||
os.environ.get(
|
||||
'POSTGRES_HSTORE', DEFAULT_POSTGRES_HSTORE)),
|
||||
enable_json=util.strtobool(
|
||||
os.environ.get('POSTGRES_JSON', DEFAULT_POSTGRES_JSON)),
|
||||
enable_uuid=util.strtobool(
|
||||
os.environ.get('POSTGRES_UUID', DEFAULT_POSTGRES_UUID)),
|
||||
echo=False,
|
||||
on_connect=None,
|
||||
pool_recycle=int(
|
||||
os.environ.get(
|
||||
'POSTGRES_CONNECTION_TTL',
|
||||
DEFAULT_POSTGRES_CONNECTION_TTL)))
|
||||
try:
|
||||
async with self._postgres_pool._cond:
|
||||
await self._postgres_pool._fill_free_pool(False)
|
||||
except (psycopg2.OperationalError,
|
||||
psycopg2.Error) as error: # pragma: nocover
|
||||
LOGGER.warning('Error connecting to PostgreSQL on startup: %s',
|
||||
error)
|
||||
LOGGER.debug('Connected to Postgres')
|
||||
|
||||
async def _postgres_shutdown(self, _ioloop: ioloop.IOLoop) -> None:
|
||||
"""Shutdown the Postgres connections and wait for them to close.
|
||||
|
||||
This is invoked by the :class:`sprockets.http.app.Application` shutdown
|
||||
callback mechanism.
|
||||
|
||||
"""
|
||||
if self._postgres_pool is not None:
|
||||
self._postgres_pool.close()
|
||||
await self._postgres_pool.wait_closed()
|
||||
netloc.append('{}:{}@'.format(parsed.username, parsed.password))
|
||||
netloc.append(','.join([
|
||||
'{}:{}'.format(r.host, r.port) for r in records]))
|
||||
return parse.urlunparse(
|
||||
('postgresql', ''.join(netloc), parsed.path,
|
||||
parsed.params, parsed.query, ''))
|
||||
|
||||
@staticmethod
|
||||
async def _resolve_srv(hostname: str) \
|
||||
|
|
93
tests.py
93
tests.py
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
import collections
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
import typing
|
||||
|
@ -7,17 +8,21 @@ import unittest
|
|||
import uuid
|
||||
from urllib import parse
|
||||
|
||||
import aiopg
|
||||
import asynctest
|
||||
import psycopg2
|
||||
import pycares
|
||||
from asynctest import mock
|
||||
from psycopg2 import errors
|
||||
from sprockets.http import app, testing
|
||||
from tornado import ioloop, web
|
||||
from tornado import ioloop, testing as ttesting, web
|
||||
|
||||
import sprockets_postgres
|
||||
|
||||
|
||||
test_postgres_cursor_oer_invocation = 0
|
||||
|
||||
|
||||
class RequestHandler(sprockets_postgres.RequestHandlerMixin,
|
||||
web.RequestHandler):
|
||||
"""Base RequestHandler for test endpoints"""
|
||||
|
@ -223,6 +228,11 @@ class Application(sprockets_postgres.ApplicationMixin,
|
|||
|
||||
class TestCase(testing.SprocketsHttpTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
self.app._postgres_connected.wait())
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
with open('build/test-environment') as f:
|
||||
|
@ -267,6 +277,12 @@ class RequestHandlerMixinTestCase(TestCase):
|
|||
self.assertEqual(response.code, 503)
|
||||
self.assertFalse(json.loads(response.body)['available'])
|
||||
|
||||
def test_postgres_status_not_connected(self):
|
||||
self.app._postgres_connected.clear()
|
||||
response = self.fetch('/status')
|
||||
self.assertEqual(response.code, 503)
|
||||
self.assertFalse(json.loads(response.body)['available'])
|
||||
|
||||
@mock.patch('aiopg.cursor.Cursor.execute')
|
||||
def test_postgres_status_error(self, execute):
|
||||
execute.side_effect = asyncio.TimeoutError()
|
||||
|
@ -384,6 +400,55 @@ class RequestHandlerMixinTestCase(TestCase):
|
|||
response = self.fetch('/execute?value=1')
|
||||
self.assertEqual(response.code, 503)
|
||||
|
||||
def test_postgres_cursor_operational_error_reconnects(self):
|
||||
original = aiopg.connection.Connection.cursor
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def mock_cursor(self, name=None, cursor_factory=None,
|
||||
scrollable=None, withhold=False, timeout=None):
|
||||
global test_postgres_cursor_oer_invocation
|
||||
|
||||
test_postgres_cursor_oer_invocation += 1
|
||||
if test_postgres_cursor_oer_invocation == 1:
|
||||
raise psycopg2.OperationalError()
|
||||
async with original(self, name, cursor_factory, scrollable,
|
||||
withhold, timeout) as value:
|
||||
yield value
|
||||
|
||||
aiopg.connection.Connection.cursor = mock_cursor
|
||||
|
||||
with mock.patch.object(self.app, '_postgres_connect') as connect:
|
||||
response = self.fetch('/execute?value=1')
|
||||
self.assertEqual(response.code, 200)
|
||||
self.assertEqual(json.loads(response.body)['value'], '1')
|
||||
connect.assert_called_once()
|
||||
|
||||
aiopg.connection.Connection.cursor = original
|
||||
|
||||
@mock.patch('aiopg.connection.Connection.cursor')
|
||||
def test_postgres_cursor_raises(self, cursor):
|
||||
cursor.side_effect = psycopg2.OperationalError()
|
||||
with mock.patch.object(self.app, '_postgres_connect') as connect:
|
||||
connect.return_value = False
|
||||
response = self.fetch('/execute?value=1')
|
||||
self.assertEqual(response.code, 503)
|
||||
connect.assert_called_once()
|
||||
|
||||
@ttesting.gen_test()
|
||||
@mock.patch('aiopg.connection.Connection.cursor')
|
||||
async def test_postgres_cursor_failure_concurrency(self, cursor):
|
||||
cursor.side_effect = psycopg2.OperationalError()
|
||||
|
||||
def on_error(*args):
|
||||
return RuntimeError
|
||||
|
||||
async def invoke_cursor():
|
||||
async with self.app.postgres_connector(on_error) as connector:
|
||||
await connector.execute('SELECT 1')
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
await asyncio.gather(invoke_cursor(), invoke_cursor())
|
||||
|
||||
|
||||
class TransactionTestCase(TestCase):
|
||||
|
||||
|
@ -464,9 +529,9 @@ class SRVTestCase(asynctest.TestCase):
|
|||
with mock.patch.object(obj, '_resolve_srv') as resolve_srv:
|
||||
resolve_srv.return_value = []
|
||||
os.environ['POSTGRES_URL'] = 'aws+srv://foo@bar/baz'
|
||||
await obj._postgres_setup(obj, loop)
|
||||
await obj._postgres_on_start(obj, loop)
|
||||
stop.assert_called_once_with(loop)
|
||||
critical.assert_called_once_with('No SRV records found')
|
||||
critical.assert_any_call('No SRV records found')
|
||||
|
||||
@mock.patch('sprockets.http.app.Application.stop')
|
||||
@mock.patch('sprockets_postgres.LOGGER.critical')
|
||||
|
@ -476,9 +541,9 @@ class SRVTestCase(asynctest.TestCase):
|
|||
with mock.patch.object(obj, '_resolve_srv') as resolve_srv:
|
||||
resolve_srv.return_value = []
|
||||
os.environ['POSTGRES_URL'] = 'postgresql+srv://foo@bar/baz'
|
||||
await obj._postgres_setup(obj, loop)
|
||||
await obj._postgres_on_start(obj, loop)
|
||||
stop.assert_called_once_with(loop)
|
||||
critical.assert_called_once_with('No SRV records found')
|
||||
critical.assert_any_call('No SRV records found')
|
||||
|
||||
@mock.patch('sprockets.http.app.Application.stop')
|
||||
@mock.patch('sprockets_postgres.LOGGER.critical')
|
||||
|
@ -486,9 +551,9 @@ class SRVTestCase(asynctest.TestCase):
|
|||
obj = Application()
|
||||
loop = ioloop.IOLoop.current()
|
||||
os.environ['POSTGRES_URL'] = 'postgres+srv://foo@bar/baz'
|
||||
await obj._postgres_setup(obj, loop)
|
||||
await obj._postgres_on_start(obj, loop)
|
||||
stop.assert_called_once_with(loop)
|
||||
critical.assert_called_once_with(
|
||||
critical.assert_any_call(
|
||||
'Unsupported URI Scheme: postgres+srv')
|
||||
|
||||
@mock.patch('aiodns.DNSResolver.query')
|
||||
|
@ -504,7 +569,8 @@ class SRVTestCase(asynctest.TestCase):
|
|||
query.return_value = future
|
||||
url = await obj._postgres_url_from_srv(parsed)
|
||||
query.assert_called_once_with('bar.baz', 'SRV')
|
||||
self.assertEqual(url, 'postgresql://foo@foo1:5432/qux')
|
||||
self.assertEqual(
|
||||
url, 'postgresql://foo@foo1:5432,foo3:6432,foo2:5432/qux')
|
||||
|
||||
@mock.patch('aiodns.DNSResolver.query')
|
||||
async def test_postgresql_url_from_srv_variation_1(self, query):
|
||||
|
@ -518,7 +584,7 @@ class SRVTestCase(asynctest.TestCase):
|
|||
query.return_value = future
|
||||
url = await obj._postgres_url_from_srv(parsed)
|
||||
query.assert_called_once_with('_bar._postgresql.baz', 'SRV')
|
||||
self.assertEqual(url, 'postgresql://foo@foo1:5432/qux')
|
||||
self.assertEqual(url, 'postgresql://foo@foo1:5432,foo2:5432/qux')
|
||||
|
||||
@mock.patch('aiodns.DNSResolver.query')
|
||||
async def test_postgresql_url_from_srv_variation_2(self, query):
|
||||
|
@ -526,13 +592,14 @@ class SRVTestCase(asynctest.TestCase):
|
|||
parsed = parse.urlparse('postgresql+srv://foo:bar@baz.qux/corgie')
|
||||
future = asyncio.Future()
|
||||
future.set_result([
|
||||
SRV('foo2', 5432, 2, 0, 32),
|
||||
SRV('foo1', 5432, 1, 0, 32)
|
||||
SRV('foo2', 5432, 1, 0, 32),
|
||||
SRV('foo1', 5432, 2, 0, 32)
|
||||
])
|
||||
query.return_value = future
|
||||
url = await obj._postgres_url_from_srv(parsed)
|
||||
query.assert_called_once_with('_baz._postgresql.qux', 'SRV')
|
||||
self.assertEqual(url, 'postgresql://foo:bar@foo1:5432/corgie')
|
||||
self.assertEqual(
|
||||
url, 'postgresql://foo:bar@foo2:5432,foo1:5432/corgie')
|
||||
|
||||
@mock.patch('aiodns.DNSResolver.query')
|
||||
async def test_postgresql_url_from_srv_variation_3(self, query):
|
||||
|
@ -547,7 +614,7 @@ class SRVTestCase(asynctest.TestCase):
|
|||
query.return_value = future
|
||||
url = await obj._postgres_url_from_srv(parsed)
|
||||
query.assert_called_once_with('_foo._postgresql.bar', 'SRV')
|
||||
self.assertEqual(url, 'postgresql://foo3:5432/baz')
|
||||
self.assertEqual(url, 'postgresql://foo3:5432,foo1:5432,foo2:5432/baz')
|
||||
|
||||
@mock.patch('aiodns.DNSResolver.query')
|
||||
async def test_resolve_srv_sorted(self, query):
|
||||
|
|
Loading…
Reference in a new issue