diff --git a/sprockets_postgres.py b/sprockets_postgres.py index 48586f2..126f4f8 100644 --- a/sprockets_postgres.py +++ b/sprockets_postgres.py @@ -373,7 +373,10 @@ class ApplicationMixin: _attempt + 1) as connector: yield connector return - exc = on_error('postgres_connector', ConnectionException(str(err))) + if on_error is None: + raise ConnectionException(str(err)) + exc = on_error( + 'postgres_connector', ConnectionException(str(err))) if exc: raise exc else: # postgres_status.on_error does not return an exception diff --git a/tests.py b/tests.py index c4bdb8a..56b6152 100644 --- a/tests.py +++ b/tests.py @@ -168,6 +168,18 @@ class NoErrorRequestHandler(ErrorRequestHandler): return None +class NoMixinRequestHandler(web.RequestHandler): + + async def get(self): + try: + async with self.application.postgres_connector() as connector: + await connector.execute('SELECT 1', {}) + except sprockets_postgres.ConnectionException: + raise web.HTTPError(503) + else: + raise web.HTTPError(418) + + class NoRowRequestHandler(RequestHandler): GET_SQL = """\ @@ -257,7 +269,7 @@ class UnhandledExceptionRequestHandler(RequestHandler): await self.postgres_execute(self.GET_SQL) except psycopg2.DataError: raise web.HTTPError(422) - raise web.HTTPError(500, 'This should have failed') + raise web.HTTPError(418, 'This should have failed') def on_postgres_error(self, metric_name: str, @@ -271,6 +283,14 @@ class UnhandledExceptionRequestHandler(RequestHandler): return exc +class DontRaiseRequestHandler(UnhandledExceptionRequestHandler): + + def on_postgres_error(self, + metric_name: str, + exc: Exception) -> typing.Optional[Exception]: + return None + + class Application(sprockets_postgres.ApplicationMixin, app.Application): @@ -303,6 +323,7 @@ class TestCase(testing.SprocketsHttpTestCase): self.app = Application(handlers=[ web.url('/callproc', CallprocRequestHandler), web.url('/count', CountRequestHandler), + web.url('/dont-raise', DontRaiseRequestHandler), web.url('/error', ErrorRequestHandler), web.url('/error-passthrough', ErrorPassthroughRequestHandler), web.url('/execute', ExecuteRequestHandler), @@ -310,6 +331,7 @@ class TestCase(testing.SprocketsHttpTestCase): web.url('/metrics-mixin', MetricsMixinRequestHandler), web.url('/multi-row', MultiRowRequestHandler), web.url('/no-error', NoErrorRequestHandler), + web.url('/no-mixin', NoMixinRequestHandler), web.url('/no-row', NoRowRequestHandler), web.url('/pdexecute', ProblemDetailsExecuteRequestHandler), web.url('/row-count-no-rows', RowCountNoRowsRequestHandler), @@ -529,7 +551,7 @@ class RequestHandlerMixinTestCase(TestCase): self.assertIsNone(json.loads(response.body)['value']) -class RequestHandlerMixinHTTPErrorTestCase(TestCase): +class HTTPErrorTestCase(TestCase): def setUp(self): super().setUp() @@ -574,6 +596,25 @@ class RequestHandlerMixinHTTPErrorTestCase(TestCase): response = self.fetch('/execute?value=1') self.assertEqual(response.code, 503) + def test_on_error_no_exception_branch(self): + response = self.fetch('/dont-raise') + self.assertEqual(response.code, 418) + + +class NoMixinTestCase(TestCase): + + @mock.patch('aiopg.cursor.Cursor.execute') + def test_postgres_cursor_raises(self, execute): + execute.side_effect = psycopg2.ProgrammingError() + response = self.fetch('/no-mixin') + self.assertEqual(response.code, 500) + + @mock.patch('aiopg.pool.Pool.acquire') + def test_postgres_status_connect_error(self, acquire): + acquire.side_effect = asyncio.TimeoutError() + response = self.fetch('/no-mixin') + self.assertEqual(response.code, 503) + class TransactionTestCase(TestCase):