Merge pull request #10 from gmr/master

Attempt to fix the reconnection race condition
This commit is contained in:
Andrew Rabert 2021-04-15 11:22:05 -04:00 committed by GitHub
commit d9b2ed5c69
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 83 additions and 15 deletions

View file

@ -1 +1 @@
1.7.0 1.8.0

View file

@ -71,7 +71,7 @@ report_done
cat > build/test-environment<<EOF cat > build/test-environment<<EOF
export ASYNC_TEST_TIMEOUT=5 export ASYNC_TEST_TIMEOUT=5
export POSTGRES_URL=postgresql://postgres@${TEST_HOST}:$(get_exposed_port postgres 5432)/postgres export POSTGRES_URL=postgresql://postgres@${TEST_HOST}:$(get_exposed_port postgres 5432)/postgres?application_name=sprockets_postgres
EOF EOF
printf "\nBootstrap complete\n\nDon't forget to \"source build/test-environment\"\n" printf "\nBootstrap complete\n\nDon't forget to \"source build/test-environment\"\n"

View file

@ -361,18 +361,27 @@ class ApplicationMixin:
as cursor: as cursor:
yield PostgresConnector( yield PostgresConnector(
cursor, on_error, on_duration, timeout) cursor, on_error, on_duration, timeout)
except (asyncio.TimeoutError, psycopg2.OperationalError) as err: except (asyncio.TimeoutError,
psycopg2.OperationalError,
RuntimeError) as err:
if isinstance(err, psycopg2.OperationalError) and _attempt == 1: if isinstance(err, psycopg2.OperationalError) and _attempt == 1:
LOGGER.critical('Disconnected from Postgres: %s', err) LOGGER.critical('Disconnected from Postgres: %s', err)
if not self._postgres_reconnect.locked(): if not self._postgres_reconnect.locked():
async with self._postgres_reconnect: async with self._postgres_reconnect:
LOGGER.info('Reconnecting to Postgres with new Pool')
if await self._postgres_connect(): if await self._postgres_connect():
await self._postgres_connected.wait() try:
async with self.postgres_connector( await asyncio.wait_for(
on_error, on_duration, timeout, self._postgres_connected.wait(),
_attempt + 1) as connector: self._postgres_settings['timeout'])
yield connector except asyncio.TimeoutError as error:
return err = error
else:
async with self.postgres_connector(
on_error, on_duration, timeout,
_attempt + 1) as connector:
yield connector
return
if on_error is None: if on_error is None:
raise ConnectionException(str(err)) raise ConnectionException(str(err))
exc = on_error( exc = on_error(
@ -382,6 +391,12 @@ class ApplicationMixin:
else: # postgres_status.on_error does not return an exception else: # postgres_status.on_error does not return an exception
yield None yield None
@property
def postgres_is_connected(self) -> bool:
    """Report whether Postgres is currently connected.

    The result is ``True`` only when ``self._postgres_connected`` has been
    created (is not ``None``) and its ``is_set()`` method reports true.

    """
    connected = self._postgres_connected
    if connected is None:
        return False
    return connected.is_set()
async def postgres_status(self) -> dict: async def postgres_status(self) -> dict:
"""Invoke from the ``/status`` RequestHandler to check that there is """Invoke from the ``/status`` RequestHandler to check that there is
a Postgres connection handler available and return info about the a Postgres connection handler available and return info about the
@ -408,8 +423,7 @@ class ApplicationMixin:
} }
""" """
if not self._postgres_connected or \ if not self.postgres_is_connected:
not self._postgres_connected.is_set():
return { return {
'available': False, 'available': False,
'pool_size': 0, 'pool_size': 0,
@ -497,12 +511,12 @@ class ApplicationMixin:
else: else:
url = self._postgres_settings['url'] url = self._postgres_settings['url']
if self._postgres_pool:
self._postgres_pool.close()
safe_url = self._obscure_url_password(url) safe_url = self._obscure_url_password(url)
LOGGER.debug('Connecting to %s', safe_url) LOGGER.debug('Connecting to %s', safe_url)
if self._postgres_pool and not self._postgres_pool.closed:
self._postgres_pool.close()
try: try:
self._postgres_pool = await pool.Pool.from_pool_fill( self._postgres_pool = await pool.Pool.from_pool_fill(
url, url,
@ -513,7 +527,7 @@ class ApplicationMixin:
enable_json=self._postgres_settings['enable_json'], enable_json=self._postgres_settings['enable_json'],
enable_uuid=self._postgres_settings['enable_uuid'], enable_uuid=self._postgres_settings['enable_uuid'],
echo=False, echo=False,
on_connect=None, on_connect=self._on_postgres_connect,
pool_recycle=self._postgres_settings['connection_ttl']) pool_recycle=self._postgres_settings['connection_ttl'])
except psycopg2.Error as error: # pragma: nocover except psycopg2.Error as error: # pragma: nocover
LOGGER.warning( LOGGER.warning(
@ -535,6 +549,9 @@ class ApplicationMixin:
url = parse.urlunparse(parsed._replace(netloc=netloc)) url = parse.urlunparse(parsed._replace(netloc=netloc))
return url return url
async def _on_postgres_connect(self, conn) -> None:
    """Log, at debug level, that a new Postgres connection was made.

    :param conn: The connection object handed to this callback; it is
        only interpolated into the log message.

    """
    LOGGER.debug('New postgres connection %s', conn)
async def _postgres_on_start(self, async def _postgres_on_start(self,
_app: web.Application, _app: web.Application,
loop: ioloop.IOLoop): loop: ioloop.IOLoop):
@ -640,6 +657,7 @@ class RequestHandlerMixin:
:rtype: :class:`~sprockets_postgres.QueryResult` :rtype: :class:`~sprockets_postgres.QueryResult`
""" """
self._postgres_connection_check()
async with self.application.postgres_connector( async with self.application.postgres_connector(
self.on_postgres_error, self.on_postgres_error,
self.on_postgres_timing, self.on_postgres_timing,
@ -679,6 +697,7 @@ class RequestHandlerMixin:
:rtype: :class:`~sprockets_postgres.QueryResult` :rtype: :class:`~sprockets_postgres.QueryResult`
""" """
self._postgres_connection_check()
async with self.application.postgres_connector( async with self.application.postgres_connector(
self.on_postgres_error, self.on_postgres_error,
self.on_postgres_timing, self.on_postgres_timing,
@ -726,6 +745,7 @@ class RequestHandlerMixin:
likely be more specific. likely be more specific.
""" """
self._postgres_connection_check()
async with self.application.postgres_connector( async with self.application.postgres_connector(
self.on_postgres_error, self.on_postgres_error,
self.on_postgres_timing, self.on_postgres_timing,
@ -771,6 +791,11 @@ class RequestHandlerMixin:
raise problemdetails.Problem( raise problemdetails.Problem(
status_code=409, title='Unique Violation') status_code=409, title='Unique Violation')
raise web.HTTPError(409, reason='Unique Violation') raise web.HTTPError(409, reason='Unique Violation')
elif isinstance(exc, psycopg2.OperationalError):
if problemdetails:
raise problemdetails.Problem(
status_code=503, title='Database Error')
raise web.HTTPError(503, reason='Database Error')
elif isinstance(exc, psycopg2.Error): elif isinstance(exc, psycopg2.Error):
if problemdetails: if problemdetails:
raise problemdetails.Problem( raise problemdetails.Problem(
@ -801,6 +826,19 @@ class RequestHandlerMixin:
LOGGER.debug('Postgres query %s duration: %s', LOGGER.debug('Postgres query %s duration: %s',
metric_name, duration) metric_name, duration)
def _postgres_connection_check(self):
    """Abort the request with a 503 error when Postgres is not connected.

    Returns without doing anything when
    ``self.application.postgres_is_connected`` is true. Otherwise a
    ``problemdetails.Problem`` is raised when ``problemdetails`` is
    truthy, and a ``web.HTTPError`` is raised when it is not.

    :raises: problemdetails.Problem
    :raises: web.HTTPError

    """
    if self.application.postgres_is_connected:
        return
    if not problemdetails:
        raise web.HTTPError(503, reason='Database Connection Error')
    raise problemdetails.Problem(
        status_code=503, title='Database Connection Error')
class StatusRequestHandler(web.RequestHandler): class StatusRequestHandler(web.RequestHandler):
"""A RequestHandler that can be used to expose API health or status""" """A RequestHandler that can be used to expose API health or status"""

View file

@ -361,8 +361,38 @@ class PostgresStatusTestCase(asynctest.TestCase):
'pool_free': 0}) 'pool_free': 0})
class ReconnectionTestCast(TestCase):
    """Verify that requests succeed again after Postgres drops connections."""

    @ttesting.gen_test
    async def test_postgres_reconnect(self):
        """Terminate the application's backends, then expect a 200 response.

        A first request proves the endpoint works, the server-side
        backends named ``sprockets_postgres`` are terminated, and a second
        request must still return 200.

        """
        response = await self.http_client.fetch(self.get_url('/callproc'))
        self.assertEqual(response.code, 200)
        self.assertIsInstance(
            uuid.UUID(json.loads(response.body)['value']), uuid.UUID)

        # The query string is dropped from the URL for this helper
        # connection -- presumably so it does not carry the
        # application_name targeted by the terminate query below.
        conn = await aiopg.connect(os.environ['POSTGRES_URL'].split('?')[0])
        try:
            # Force close all open connections for tests
            cursor = await conn.cursor()
            await cursor.execute(
                'SELECT pg_terminate_backend(pid)'
                ' FROM pg_stat_activity'
                " WHERE application_name = 'sprockets_postgres'")
            await cursor.fetchall()
            await asyncio.sleep(1)

            response = await self.http_client.fetch(
                self.get_url('/callproc'), raise_error=False)
            self.assertEqual(response.code, 200)
        finally:
            # Close the helper connection even when the fetch raises or the
            # assertion fails, so a failing run does not leak it.
            conn.close()
class RequestHandlerMixinTestCase(TestCase): class RequestHandlerMixinTestCase(TestCase):
def test_postgres_connected(self):
    """``/status`` reports ``ok`` and the app reports being connected."""
    payload = json.loads(self.fetch('/status').body)
    self.assertEqual(payload['status'], 'ok')
    self.assertTrue(self.app.postgres_is_connected)
def test_postgres_status(self): def test_postgres_status(self):
response = self.fetch('/status') response = self.fetch('/status')
data = json.loads(response.body) data = json.loads(response.body)