diff --git a/VERSION b/VERSION index 26aaba0..f0bb29e 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.2.0 +1.3.0 diff --git a/docs/configuration.rst b/docs/configuration.rst index d08c29b..4d9ef97 100644 --- a/docs/configuration.rst +++ b/docs/configuration.rst @@ -27,10 +27,10 @@ details the configuration options and their defaults. +---------------------------------+--------------------------------------------------+-----------+ If ``POSTGRES_URL`` uses a scheme of ``postgresql+srv``, a SRV DNS lookup will be -performed and the lowest priority record with the highest weight will be selected -for connecting to Postgres. +performed and the URL will be constructed containing all host and port combinations +in priority and weighted order, utilizing `libpq's supoprt `_ +for multiple hosts in a URL. AWS's ECS service discovery does not follow the SRV standard, but creates SRV records. If ``POSTGRES_URL`` uses a scheme of ``aws+srv``, a SRV DNS lookup will be -performed using the correct format for ECS service discovery. The lowest priority -record with the highest weight will be selected for connecting to Postgres. +performed using the correct format for ECS service discovery. diff --git a/sprockets_postgres.py b/sprockets_postgres.py index fbe3bac..4895fb1 100644 --- a/sprockets_postgres.py +++ b/sprockets_postgres.py @@ -276,10 +276,9 @@ class ConnectionException(Exception): class ApplicationMixin: - """ - :class:`sprockets.http.app.Application` / :class:`tornado.web.Application` - mixin for handling the connection to Postgres and exporting functions for - querying the database, getting the status, and proving a cursor. + """:class:`sprockets.http.app.Application` mixin for handling the + connection to Postgres and exporting functions for querying the database, + getting the status, and proving a cursor. Automatically creates and shuts down :class:`aiopg.Pool` on startup and shutdown by installing `on_start` and `shutdown` callbacks into the @@ -291,7 +290,10 @@ class ApplicationMixin: def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._postgres_pool: typing.Optional[pool.Pool] = None - self.runner_callbacks['on_start'].append(self._postgres_setup) + self._postgres_connected = asyncio.Event() + self._postgres_reconnect = asyncio.Lock() + self._postgres_srv: bool = False + self.runner_callbacks['on_start'].append(self._postgres_on_start) self.runner_callbacks['shutdown'].append(self._postgres_shutdown) @contextlib.asynccontextmanager @@ -299,7 +301,8 @@ class ApplicationMixin: on_error: typing.Callable, on_duration: typing.Optional[ typing.Callable] = None, - timeout: Timeout = None) \ + timeout: Timeout = None, + _attempt: int = 1) \ -> typing.AsyncContextManager[PostgresConnector]: """Asynchronous :ref:`context-manager ` that returns a :class:`~sprockets_postgres.PostgresConnector` instance @@ -338,7 +341,19 @@ class ApplicationMixin: yield PostgresConnector( cursor, on_error, on_duration, timeout) except (asyncio.TimeoutError, psycopg2.Error) as err: - exc = on_error('postgres_connector', ConnectionException(str(err))) + message = str(err) + if isinstance(err, psycopg2.OperationalError) and _attempt == 1: + LOGGER.critical('Disconnected from Postgres: %s', err) + if not self._postgres_reconnect.locked(): + async with self._postgres_reconnect: + if await self._postgres_connect(): + async with self.postgres_connector( + on_error, on_duration, timeout, + _attempt + 1) as connector: + yield connector + return + message = 'disconnected' + exc = on_error('postgres_connector', ConnectionException(message)) if exc: raise exc else: # postgres_status.on_error does not return an exception @@ -370,6 +385,13 @@ class ApplicationMixin: } """ + if not self._postgres_connected.is_set(): + return { + 'available': False, + 'pool_size': 0, + 'pool_free': 0 + } + LOGGER.debug('Querying postgres status') query_error = asyncio.Event() @@ -390,6 +412,89 @@ class ApplicationMixin: 'pool_free': self._postgres_pool.freesize } + async def _postgres_connect(self) -> bool: + """Setup the Postgres pool of connections""" + self._postgres_connected.clear() + + parsed = parse.urlparse(os.environ['POSTGRES_URL']) + if parsed.scheme.endswith('+srv'): + self._postgres_srv = True + try: + url = await self._postgres_url_from_srv(parsed) + except RuntimeError as error: + LOGGER.critical(str(error)) + return False + else: + url = os.environ['POSTGRES_URL'] + + if self._postgres_pool: + self._postgres_pool.close() + + LOGGER.debug('Connecting to %s', os.environ['POSTGRES_URL']) + try: + self._postgres_pool = await pool.Pool.from_pool_fill( + url, + maxsize=int( + os.environ.get( + 'POSTGRES_MAX_POOL_SIZE', + DEFAULT_POSTGRES_MAX_POOL_SIZE)), + minsize=int( + os.environ.get( + 'POSTGRES_MIN_POOL_SIZE', + DEFAULT_POSTGRES_MIN_POOL_SIZE)), + timeout=int( + os.environ.get( + 'POSTGRES_CONNECT_TIMEOUT', + DEFAULT_POSTGRES_CONNECTION_TIMEOUT)), + enable_hstore=util.strtobool( + os.environ.get( + 'POSTGRES_HSTORE', DEFAULT_POSTGRES_HSTORE)), + enable_json=util.strtobool( + os.environ.get('POSTGRES_JSON', DEFAULT_POSTGRES_JSON)), + enable_uuid=util.strtobool( + os.environ.get('POSTGRES_UUID', DEFAULT_POSTGRES_UUID)), + echo=False, + on_connect=None, + pool_recycle=int( + os.environ.get( + 'POSTGRES_CONNECTION_TTL', + DEFAULT_POSTGRES_CONNECTION_TTL))) + except (psycopg2.OperationalError, + psycopg2.Error) as error: # pragma: nocover + LOGGER.warning('Error connecting to PostgreSQL on startup: %s', + error) + return False + self._postgres_connected.set() + LOGGER.debug('Connected to Postgres') + return True + + async def _postgres_on_start(self, + _app: web.Application, + loop: ioloop.IOLoop): + """Invoked as a startup step for the application + + This is invoked by the :class:`sprockets.http.app.Application` on start + callback mechanism. + + """ + if 'POSTGRES_URL' not in os.environ: + LOGGER.critical('Missing POSTGRES_URL environment variable') + return self.stop(loop) + if not await self._postgres_connect(): + LOGGER.critical('PostgreSQL failed to connect, shutting down') + return self.stop(loop) + + async def _postgres_shutdown(self, _ioloop: ioloop.IOLoop) -> None: + """Shutdown the Postgres connections and wait for them to close. + + This is invoked by the :class:`sprockets.http.app.Application` shutdown + callback mechanism. + + """ + if self._postgres_pool is not None: + self._postgres_pool.close() + await self._postgres_pool.wait_closed() + async def _postgres_url_from_srv(self, parsed: parse.ParseResult) -> str: if parsed.scheme.startswith('postgresql+'): host_parts = parsed.hostname.split('.') @@ -405,86 +510,16 @@ class ApplicationMixin: if not records: raise RuntimeError('No SRV records found') + netloc = [] if parsed.username and not parsed.password: - return 'postgresql://{}@{}:{}{}'.format( - parsed.username, records[0].host, records[0].port, parsed.path) + netloc.append('{}@'.format(parsed.username)) elif parsed.username and parsed.password: - return 'postgresql://{}:{}@{}:{}{}'.format( - parsed.username, parsed.password, - records[0].host, records[0].port, parsed.path) - return 'postgresql://{}:{}{}'.format( - records[0].host, records[0].port, parsed.path) - - async def _postgres_setup(self, - _app: web.Application, - loop: ioloop.IOLoop) -> None: - """Setup the Postgres pool of connections and log if there is an error. - - This is invoked by the :class:`sprockets.http.app.Application` on start - callback mechanism. - - """ - if 'POSTGRES_URL' not in os.environ: - LOGGER.critical('Missing POSTGRES_URL environment variable') - return self.stop(loop) - - parsed = parse.urlparse(os.environ['POSTGRES_URL']) - if parsed.scheme.endswith('+srv'): - try: - url = await self._postgres_url_from_srv(parsed) - except RuntimeError as error: - LOGGER.critical(str(error)) - return self.stop(loop) - else: - url = os.environ['POSTGRES_URL'] - - LOGGER.debug('Connecting to %s', os.environ['POSTGRES_URL']) - self._postgres_pool = pool.Pool( - url, - maxsize=int( - os.environ.get( - 'POSTGRES_MAX_POOL_SIZE', - DEFAULT_POSTGRES_MAX_POOL_SIZE)), - minsize=int( - os.environ.get( - 'POSTGRES_MIN_POOL_SIZE', - DEFAULT_POSTGRES_MIN_POOL_SIZE)), - timeout=int( - os.environ.get( - 'POSTGRES_CONNECT_TIMEOUT', - DEFAULT_POSTGRES_CONNECTION_TIMEOUT)), - enable_hstore=util.strtobool( - os.environ.get( - 'POSTGRES_HSTORE', DEFAULT_POSTGRES_HSTORE)), - enable_json=util.strtobool( - os.environ.get('POSTGRES_JSON', DEFAULT_POSTGRES_JSON)), - enable_uuid=util.strtobool( - os.environ.get('POSTGRES_UUID', DEFAULT_POSTGRES_UUID)), - echo=False, - on_connect=None, - pool_recycle=int( - os.environ.get( - 'POSTGRES_CONNECTION_TTL', - DEFAULT_POSTGRES_CONNECTION_TTL))) - try: - async with self._postgres_pool._cond: - await self._postgres_pool._fill_free_pool(False) - except (psycopg2.OperationalError, - psycopg2.Error) as error: # pragma: nocover - LOGGER.warning('Error connecting to PostgreSQL on startup: %s', - error) - LOGGER.debug('Connected to Postgres') - - async def _postgres_shutdown(self, _ioloop: ioloop.IOLoop) -> None: - """Shutdown the Postgres connections and wait for them to close. - - This is invoked by the :class:`sprockets.http.app.Application` shutdown - callback mechanism. - - """ - if self._postgres_pool is not None: - self._postgres_pool.close() - await self._postgres_pool.wait_closed() + netloc.append('{}:{}@'.format(parsed.username, parsed.password)) + netloc.append(','.join([ + '{}:{}'.format(r.host, r.port) for r in records])) + return parse.urlunparse( + ('postgresql', ''.join(netloc), parsed.path, + parsed.params, parsed.query, '')) @staticmethod async def _resolve_srv(hostname: str) \ diff --git a/tests.py b/tests.py index 4c2b210..d6b3679 100644 --- a/tests.py +++ b/tests.py @@ -1,5 +1,6 @@ import asyncio import collections +import contextlib import json import os import typing @@ -7,17 +8,21 @@ import unittest import uuid from urllib import parse +import aiopg import asynctest import psycopg2 import pycares from asynctest import mock from psycopg2 import errors from sprockets.http import app, testing -from tornado import ioloop, web +from tornado import ioloop, testing as ttesting, web import sprockets_postgres +test_postgres_cursor_oer_invocation = 0 + + class RequestHandler(sprockets_postgres.RequestHandlerMixin, web.RequestHandler): """Base RequestHandler for test endpoints""" @@ -223,6 +228,11 @@ class Application(sprockets_postgres.ApplicationMixin, class TestCase(testing.SprocketsHttpTestCase): + def setUp(self): + super().setUp() + asyncio.get_event_loop().run_until_complete( + self.app._postgres_connected.wait()) + @classmethod def setUpClass(cls): with open('build/test-environment') as f: @@ -267,6 +277,12 @@ class RequestHandlerMixinTestCase(TestCase): self.assertEqual(response.code, 503) self.assertFalse(json.loads(response.body)['available']) + def test_postgres_status_not_connected(self): + self.app._postgres_connected.clear() + response = self.fetch('/status') + self.assertEqual(response.code, 503) + self.assertFalse(json.loads(response.body)['available']) + @mock.patch('aiopg.cursor.Cursor.execute') def test_postgres_status_error(self, execute): execute.side_effect = asyncio.TimeoutError() @@ -384,6 +400,55 @@ class RequestHandlerMixinTestCase(TestCase): response = self.fetch('/execute?value=1') self.assertEqual(response.code, 503) + def test_postgres_cursor_operational_error_reconnects(self): + original = aiopg.connection.Connection.cursor + + @contextlib.asynccontextmanager + async def mock_cursor(self, name=None, cursor_factory=None, + scrollable=None, withhold=False, timeout=None): + global test_postgres_cursor_oer_invocation + + test_postgres_cursor_oer_invocation += 1 + if test_postgres_cursor_oer_invocation == 1: + raise psycopg2.OperationalError() + async with original(self, name, cursor_factory, scrollable, + withhold, timeout) as value: + yield value + + aiopg.connection.Connection.cursor = mock_cursor + + with mock.patch.object(self.app, '_postgres_connect') as connect: + response = self.fetch('/execute?value=1') + self.assertEqual(response.code, 200) + self.assertEqual(json.loads(response.body)['value'], '1') + connect.assert_called_once() + + aiopg.connection.Connection.cursor = original + + @mock.patch('aiopg.connection.Connection.cursor') + def test_postgres_cursor_raises(self, cursor): + cursor.side_effect = psycopg2.OperationalError() + with mock.patch.object(self.app, '_postgres_connect') as connect: + connect.return_value = False + response = self.fetch('/execute?value=1') + self.assertEqual(response.code, 503) + connect.assert_called_once() + + @ttesting.gen_test() + @mock.patch('aiopg.connection.Connection.cursor') + async def test_postgres_cursor_failure_concurrency(self, cursor): + cursor.side_effect = psycopg2.OperationalError() + + def on_error(*args): + return RuntimeError + + async def invoke_cursor(): + async with self.app.postgres_connector(on_error) as connector: + await connector.execute('SELECT 1') + + with self.assertRaises(RuntimeError): + await asyncio.gather(invoke_cursor(), invoke_cursor()) + class TransactionTestCase(TestCase): @@ -464,9 +529,9 @@ class SRVTestCase(asynctest.TestCase): with mock.patch.object(obj, '_resolve_srv') as resolve_srv: resolve_srv.return_value = [] os.environ['POSTGRES_URL'] = 'aws+srv://foo@bar/baz' - await obj._postgres_setup(obj, loop) + await obj._postgres_on_start(obj, loop) stop.assert_called_once_with(loop) - critical.assert_called_once_with('No SRV records found') + critical.assert_any_call('No SRV records found') @mock.patch('sprockets.http.app.Application.stop') @mock.patch('sprockets_postgres.LOGGER.critical') @@ -476,9 +541,9 @@ class SRVTestCase(asynctest.TestCase): with mock.patch.object(obj, '_resolve_srv') as resolve_srv: resolve_srv.return_value = [] os.environ['POSTGRES_URL'] = 'postgresql+srv://foo@bar/baz' - await obj._postgres_setup(obj, loop) + await obj._postgres_on_start(obj, loop) stop.assert_called_once_with(loop) - critical.assert_called_once_with('No SRV records found') + critical.assert_any_call('No SRV records found') @mock.patch('sprockets.http.app.Application.stop') @mock.patch('sprockets_postgres.LOGGER.critical') @@ -486,9 +551,9 @@ class SRVTestCase(asynctest.TestCase): obj = Application() loop = ioloop.IOLoop.current() os.environ['POSTGRES_URL'] = 'postgres+srv://foo@bar/baz' - await obj._postgres_setup(obj, loop) + await obj._postgres_on_start(obj, loop) stop.assert_called_once_with(loop) - critical.assert_called_once_with( + critical.assert_any_call( 'Unsupported URI Scheme: postgres+srv') @mock.patch('aiodns.DNSResolver.query') @@ -504,7 +569,8 @@ class SRVTestCase(asynctest.TestCase): query.return_value = future url = await obj._postgres_url_from_srv(parsed) query.assert_called_once_with('bar.baz', 'SRV') - self.assertEqual(url, 'postgresql://foo@foo1:5432/qux') + self.assertEqual( + url, 'postgresql://foo@foo1:5432,foo3:6432,foo2:5432/qux') @mock.patch('aiodns.DNSResolver.query') async def test_postgresql_url_from_srv_variation_1(self, query): @@ -518,7 +584,7 @@ class SRVTestCase(asynctest.TestCase): query.return_value = future url = await obj._postgres_url_from_srv(parsed) query.assert_called_once_with('_bar._postgresql.baz', 'SRV') - self.assertEqual(url, 'postgresql://foo@foo1:5432/qux') + self.assertEqual(url, 'postgresql://foo@foo1:5432,foo2:5432/qux') @mock.patch('aiodns.DNSResolver.query') async def test_postgresql_url_from_srv_variation_2(self, query): @@ -526,13 +592,14 @@ class SRVTestCase(asynctest.TestCase): parsed = parse.urlparse('postgresql+srv://foo:bar@baz.qux/corgie') future = asyncio.Future() future.set_result([ - SRV('foo2', 5432, 2, 0, 32), - SRV('foo1', 5432, 1, 0, 32) + SRV('foo2', 5432, 1, 0, 32), + SRV('foo1', 5432, 2, 0, 32) ]) query.return_value = future url = await obj._postgres_url_from_srv(parsed) query.assert_called_once_with('_baz._postgresql.qux', 'SRV') - self.assertEqual(url, 'postgresql://foo:bar@foo1:5432/corgie') + self.assertEqual( + url, 'postgresql://foo:bar@foo2:5432,foo1:5432/corgie') @mock.patch('aiodns.DNSResolver.query') async def test_postgresql_url_from_srv_variation_3(self, query): @@ -547,7 +614,7 @@ class SRVTestCase(asynctest.TestCase): query.return_value = future url = await obj._postgres_url_from_srv(parsed) query.assert_called_once_with('_foo._postgresql.bar', 'SRV') - self.assertEqual(url, 'postgresql://foo3:5432/baz') + self.assertEqual(url, 'postgresql://foo3:5432,foo1:5432,foo2:5432/baz') @mock.patch('aiodns.DNSResolver.query') async def test_resolve_srv_sorted(self, query):