diff --git a/sprockets_postgres.py b/sprockets_postgres.py index 9063985..252fb46 100644 --- a/sprockets_postgres.py +++ b/sprockets_postgres.py @@ -340,26 +340,22 @@ class ApplicationMixin: timeout=timeout) as cursor: yield PostgresConnector( cursor, on_error, on_duration, timeout) - except (asyncio.TimeoutError, psycopg2.Error) as err: - message = str(err) + except (asyncio.TimeoutError, psycopg2.OperationalError) as err: if isinstance(err, psycopg2.OperationalError) and _attempt == 1: LOGGER.critical('Disconnected from Postgres: %s', err) - retry = True if not self._postgres_reconnect.locked(): async with self._postgres_reconnect: - retry = await self._postgres_connect() - if retry: - await self._postgres_connected.wait() - async with self.postgres_connector( - on_error, on_duration, timeout, - _attempt + 1) as connector: - yield connector - return - message = 'disconnected' - exc = on_error('postgres_connector', ConnectionException(message)) + if await self._postgres_connect(): + await self._postgres_connected.wait() + async with self.postgres_connector( + on_error, on_duration, timeout, + _attempt + 1) as connector: + yield connector + return + exc = on_error('postgres_connector', ConnectionException(str(err))) if exc: raise exc - else: # postgres_status.on_error does not return an exception + else: # postgres_status.on_error does not return an exception yield None async def postgres_status(self) -> dict: diff --git a/tests.py b/tests.py index a6bdcce..64f221c 100644 --- a/tests.py +++ b/tests.py @@ -216,6 +216,51 @@ class TransactionRequestHandler(RequestHandler): 'user': self.cast_data(user.row)}) +class TimeoutErrorRequestHandler(RequestHandler): + + GET_SQL = 'SELECT 1;' + + async def get(self): + await self.postgres_execute(self.GET_SQL) + raise web.HTTPError(500, 'This should have failed') + + def _on_postgres_error(self, + metric_name: str, + exc: Exception) -> typing.Optional[Exception]: + """Override for different error handling behaviors + + Return an exception if you would like for it to be raised, or swallow + it here. + + """ + if isinstance(exc, asyncio.TimeoutError): + raise web.HTTPError(418) + return exc + + +class UnhandledExceptionRequestHandler(RequestHandler): + + GET_SQL = 'SELECT 100 / 0;' + + async def get(self): + try: + await self.postgres_execute(self.GET_SQL) + except psycopg2.DataError: + raise web.HTTPError(422) + raise web.HTTPError(500, 'This should have failed') + + def _on_postgres_error(self, + metric_name: str, + exc: Exception) -> typing.Optional[Exception]: + """Override for different error handling behaviors + + Return an exception if you would like for it to be raised, or swallow + it here. + + """ + return exc + + class Application(sprockets_postgres.ApplicationMixin, app.Application): @@ -257,8 +302,10 @@ class TestCase(testing.SprocketsHttpTestCase): web.url('/no-error', NoErrorRequestHandler), web.url('/no-row', NoRowRequestHandler), web.url('/status', StatusRequestHandler), + web.url('/timeout-error', TimeoutErrorRequestHandler), web.url('/transaction', TransactionRequestHandler), - web.url('/transaction/(?P.*)', TransactionRequestHandler) + web.url('/transaction/(?P.*)', TransactionRequestHandler), + web.url('/unhandled-exception', UnhandledExceptionRequestHandler) ]) return self.app @@ -451,6 +498,16 @@ class RequestHandlerMixinTestCase(TestCase): with self.assertRaises(RuntimeError): await asyncio.gather(invoke_cursor(), invoke_cursor()) + @mock.patch('aiopg.cursor.Cursor.execute') + def test_timeout_error_when_overriding_on_postgres_error(self, execute): + execute.side_effect = asyncio.TimeoutError + response = self.fetch('/timeout-error') + self.assertEqual(response.code, 418) + + def test_unhandled_exception_in_on_postgres_error(self): + response = self.fetch('/unhandled-exception') + self.assertEqual(response.code, 422) + class TransactionTestCase(TestCase):